diff --git a/forum/models/comments.py b/forum/models/comments.py index 69002ed1..44db3f87 100644 --- a/forum/models/comments.py +++ b/forum/models/comments.py @@ -128,11 +128,14 @@ def insert( if parent_id: comment_data["parent_id"] = ObjectId(parent_id) + comment_data["endorsement"] = None + result = self._collection.insert_one(comment_data) if parent_id: self.update_child_count_in_parent_comment(parent_id, 1) if comment_thread_id: self.update_comment_count_in_comment_thread(comment_thread_id, 1) + self.update_sk(str(result.inserted_id), parent_id) if result: get_handler_by_name("comment_inserted").send( sender=self.__class__, comment_id=str(result.inserted_id) @@ -163,6 +166,7 @@ def update( editing_user_id: Optional[str] = None, edit_reason_code: Optional[str] = None, endorsement_user_id: Optional[str] = None, + sk: Optional[str] = None, ) -> int: """ Updates a comment document in the database. @@ -205,6 +209,7 @@ def update( ("child_count", child_count), ("depth", depth), ("closed", closed), + ("sk", sk), ] update_data: dict[str, Any] = { field: value for field, value in fields if value is not None @@ -221,6 +226,7 @@ def update( edit_history = [] if edit_history is None else edit_history edit_history.append( { + "author_id": editing_user_id, "original_body": original_body, "reason_code": edit_reason_code, "editor_username": self.get_author_username(editing_user_id), @@ -330,5 +336,19 @@ def update_comment_count_in_comment_thread( Returns: None. """ - update_comment_count_query = {"$inc": {"comment_count": count}} + update_comment_count_query = { + "$inc": {"comment_count": count}, + "$set": {"last_activity_at": datetime.now()}, + } CommentThread().update_count(comment_thread_id, update_comment_count_query) + + def get_sk(self, _id: str, parent_id: Optional[str]) -> str: + """Returns sk field.""" + if parent_id is not None: + return f"{parent_id}-{_id}" + return f"{_id}" + + def update_sk(self, _id: str, parent_id: Optional[str]) -> None: + """Updates sk field.""" + sk = self.get_sk(_id, parent_id) + self.update(_id, sk=sk) diff --git a/forum/models/contents.py b/forum/models/contents.py index 94aefc55..b4c130a2 100644 --- a/forum/models/contents.py +++ b/forum/models/contents.py @@ -50,7 +50,11 @@ def list(self, **kwargs: Any) -> Any: """ if self.content_type: kwargs["_type"] = self.content_type - return self._collection.find(kwargs) + result = self._collection.find(kwargs) + sort = kwargs.pop("sort", None) + if sort: + return result.sort("sk", sort) + return result @classmethod def get_votes_dict(cls, up: List[str], down: List[str]) -> dict[str, Any]: diff --git a/forum/models/model_utils.py b/forum/models/model_utils.py index 938bca1b..8149b9e3 100644 --- a/forum/models/model_utils.py +++ b/forum/models/model_utils.py @@ -4,6 +4,8 @@ from bson import ObjectId from django.core.exceptions import ObjectDoesNotExist +from rest_framework import status +from rest_framework.response import Response from forum.models import Comment, CommentThread, Contents, Subscriptions, Users @@ -332,23 +334,23 @@ def get_read_states( whether the thread is read and the unread comment count. """ read_states = {} - user = Users().find_one({"_id": user_id, "read_states.course_id": course_id}) - read_state = user["read_states"][0] if user else {} - - if read_state: - read_dates = read_state.get("last_read_times", {}) - for thread in threads: - thread_key = str(thread["_id"]) - if thread_key in read_dates: - is_read = read_dates[thread_key] >= thread["last_activity_at"] - unread_comment_count = Contents().count_documents( - { - "comment_thread_id": ObjectId(thread_key), - "created_at": {"$gte": read_dates[thread_key]}, - "author_id": {"$ne": str(user_id)}, - } - ) - read_states[thread_key] = [is_read, unread_comment_count] + if user_id: + user = Users().find_one({"_id": user_id, "read_states.course_id": course_id}) + read_state = user["read_states"][0] if user else {} + if read_state: + read_dates = read_state.get("last_read_times", {}) + for thread in threads: + thread_key = str(thread["_id"]) + if thread_key in read_dates: + is_read = read_dates[thread_key] >= thread["last_activity_at"] + unread_comment_count = Contents().count_documents( + { + "comment_thread_id": ObjectId(thread_key), + "created_at": {"$gte": read_dates[thread_key]}, + "author_id": {"$ne": str(user_id)}, + } + ) + read_states[thread_key] = [is_read, unread_comment_count] return read_states @@ -776,3 +778,110 @@ def subscribe_user( def unsubscribe_user(user_id: str, source_id: str) -> None: """Unsubscribe a user from a source.""" Subscriptions().delete_subscription(user_id, source_id) + + +def delete_comments_of_a_thread(thread_id: str) -> None: + """Delete comments of a thread.""" + for comment in Comment().list( + comment_thread_id=ObjectId(thread_id), + depth=0, + parent_id=None, + ): + Comment().delete(comment["_id"]) + + +def validate_params( + params: dict[str, Any], user_id: Optional[str] = None +) -> Response | None: + """ + Validate the request parameters. + + Args: + params (dict): The request parameters. + user_id (optional[str]): The Id of the user for validation. + + Returns: + Response: A Response object with an error message if doesn't exist. + """ + valid_params = [ + "course_id", + "author_id", + "thread_type", + "flagged", + "unread", + "unanswered", + "unresponded", + "count_flagged", + "sort_key", + "page", + "per_page", + "request_id", + ] + if not user_id: + valid_params.append("user_id") + if "user_id" not in params: + return Response( + {"error": "Missing required parameter: user_id"}, + status=status.HTTP_400_BAD_REQUEST, + ) + user_id = params.get("user_id") + + for key in params: + if key not in valid_params: + return Response( + {"error": f"Invalid parameter: {key}"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + if "course_id" not in params: + return Response( + {"error": "Missing required parameter: course_id"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + if user_id: + user = Users().get(user_id) + if not user: + return Response( + {"error": "User doesn't exist"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + return None + + +def get_threads( + params: dict[str, Any], + user_id: str, + serializer: Any, + thread_ids: list[str], + include_context: Optional[bool] = False, +) -> dict[str, Any]: + """get subscribed or all threads of a specific course for a specific user.""" + threads = handle_threads_query( + thread_ids, + user_id, + params["course_id"], + get_group_ids_from_params(params), + params.get("author_id", ""), + params.get("thread_type"), + bool(params.get("flagged", False)), + bool(params.get("unread", False)), + bool(params.get("unanswered", False)), + bool(params.get("unresponded", False)), + bool(params.get("count_flagged", False)), + params.get("sort_key", ""), + int(params.get("page", 1)), + int(params.get("per_page", 100)), + ) + context: dict[str, Any] = {} + if include_context: + context = { + "include_endorsed": True, + "include_read_state": True, + } + if user_id: + context["user_id"] = user_id + serializer = serializer(threads.pop("collection"), many=True, context=context) + threads["collection"] = serializer.data + return threads diff --git a/forum/models/threads.py b/forum/models/threads.py index 2fe7603d..da7e97af 100644 --- a/forum/models/threads.py +++ b/forum/models/threads.py @@ -6,6 +6,7 @@ from bson import ObjectId from forum.models.contents import BaseContents +from forum.models.users import Users from forum.utils import get_handler_by_name @@ -89,7 +90,7 @@ def insert( course_id: str, commentable_id: str, author_id: str, - author_username: str, + author_username: Optional[str] = None, anonymous: bool = False, anonymous_to_peers: bool = False, thread_type: str = "discussion", @@ -150,7 +151,7 @@ def insert( "anonymous_to_peers": anonymous_to_peers, "closed": False, "author_id": author_id, - "author_username": author_username, + "author_username": author_username or self.get_author_username(author_id), "created_at": date, "updated_at": date, "last_activity_at": date, @@ -188,6 +189,10 @@ def update( pinned: Optional[bool] = None, comments_count: Optional[int] = None, endorsed: Optional[bool] = None, + edit_history: Optional[list[dict[str, Any]]] = None, + original_body: Optional[str] = None, + editing_user_id: Optional[str] = None, + edit_reason_code: Optional[str] = None, ) -> int: """ Updates a thread document in the database. @@ -241,6 +246,18 @@ def update( update_data: dict[str, Any] = { field: value for field, value in fields if value is not None } + if editing_user_id: + edit_history = [] if edit_history is None else edit_history + edit_history.append( + { + "author_id": editing_user_id, + "original_body": original_body, + "reason_code": edit_reason_code, + "editor_username": self.get_author_username(editing_user_id), + "created_at": datetime.now(), + } + ) + update_data["edit_history"] = edit_history date = datetime.now() update_data["updated_at"] = date @@ -254,3 +271,8 @@ def update( sender=self.__class__, comment_thread_id=thread_id ) return result.modified_count + + def get_author_username(self, author_id: str) -> str | None: + """Return username for the respective author_id(user_id)""" + user = Users().get(author_id) + return user.get("username") if user else None diff --git a/forum/serializers/comment.py b/forum/serializers/comment.py index 39928d19..b8fd4270 100644 --- a/forum/serializers/comment.py +++ b/forum/serializers/comment.py @@ -4,10 +4,13 @@ from typing import Any +from bson import ObjectId from rest_framework import serializers +from forum.models import Comment from forum.serializers.contents import ContentSerializer from forum.serializers.custom_datetime import CustomDateTimeField +from forum.utils import prepare_comment_data_for_get_children class EndorsementSerializer(serializers.Serializer[dict[str, Any]]): @@ -49,11 +52,12 @@ class CommentSerializer(ContentSerializer): endorsed = serializers.BooleanField(default=False) depth = serializers.IntegerField(default=0) - thread_id = serializers.CharField() + thread_id = serializers.CharField(source="comment_thread_id") parent_id = serializers.CharField(default=None, allow_null=True) child_count = serializers.IntegerField(default=0) - sk = serializers.SerializerMethodField() + sk = serializers.CharField(default=None, required=False, allow_null=True) endorsement = EndorsementSerializer(default=None, required=False, allow_null=True) + children = serializers.SerializerMethodField() def __init__(self, *args: Any, **kwargs: Any) -> None: exclude_fields = kwargs.pop("exclude_fields", None) @@ -62,22 +66,33 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: for field in exclude_fields: self.fields.pop(field, None) + def get_children(self, obj: Any) -> list[dict[str, Any]]: + """Get comments of a thread.""" + if not self.context.get("recursive", False): + return [] + + children = list( + Comment().list( + parent_id=ObjectId(obj["_id"]), + depth=1, + sort=self.context.get("sort", -1), + ) + ) + children_data = prepare_comment_data_for_get_children(children) + serializer = CommentSerializer( + children_data, + many=True, + context={"recursive": False}, + exclude_fields=["sk"], + ) + return list(serializer.data) + def to_representation(self, instance: Any) -> dict[str, Any]: comment = super().to_representation(instance) if comment["parent_id"] == "None": comment["parent_id"] = None return comment - def get_sk(self, obj: dict[str, Any]) -> str: - """Return sk field""" - is_child = obj.get("parent_id") - if is_child is not None: - return "{parent_id}-{id}".format( - parent_id=obj.get("parent_id"), id=obj.get("_id") - ) - else: - return "{id}".format(id=obj.get("_id")) - def create(self, validated_data: dict[str, Any]) -> Any: """Raise NotImplementedError""" raise NotImplementedError diff --git a/forum/serializers/thread.py b/forum/serializers/thread.py index e857b94e..6759c55c 100644 --- a/forum/serializers/thread.py +++ b/forum/serializers/thread.py @@ -4,16 +4,22 @@ from typing import Any, Optional +from bson import ObjectId +from pymongo import ASCENDING, DESCENDING from rest_framework import serializers +from rest_framework.serializers import ValidationError +from forum.models import Comment from forum.models.model_utils import ( get_abuse_flagged_count, get_endorsed, get_read_states, get_username_from_id, ) +from forum.serializers.comment import CommentSerializer from forum.serializers.contents import ContentSerializer from forum.serializers.custom_datetime import CustomDateTimeField +from forum.utils import prepare_comment_data_for_get_children class ThreadSerializer(ContentSerializer): @@ -83,11 +89,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: - 'include_endorsed' (bool): Whether to include endorsement status. - 'include_read_state' (bool): Whether to include read state information. """ - context = kwargs.pop("context", {}) - self.with_responses = context.pop("with_responses", False) - self.count_flagged = context.pop("count_flagged", False) - self.include_endorsed = context.pop("include_endorsed", False) - self.include_read_state = context.pop("include_read_state", False) + self.context_data = kwargs.get("context", {}) + self.with_responses = self.context_data.pop("with_responses", False) + self.count_flagged = self.context_data.pop("count_flagged", False) + self.include_endorsed = self.context_data.pop("include_endorsed", False) + self.include_read_state = self.context_data.pop("include_read_state", False) # Customize fields based on context if not self.with_responses: @@ -119,9 +125,9 @@ def get_read(self, obj: dict[str, Any]) -> Optional[bool]: Optional[bool]: True if the thread is read, otherwise False or None. """ if self.include_read_state: - if isinstance(obj, dict) and obj.get("read") is not None: + if isinstance(obj, dict) and obj.get("read") is None: return obj.get("read", True) - user_id = obj["author_id"] + user_id = self.context_data.get("user_id", None) course_id = obj["course_id"] thread_key = obj["_id"] is_read, _ = get_read_states([obj], user_id, course_id).get( @@ -141,9 +147,9 @@ def get_unread_comments_count(self, obj: dict[str, Any]) -> Optional[int]: Optional[int]: The number of unread comments or None. """ if self.include_read_state: - if isinstance(obj, dict) and obj.get("unread_comments_count") is not None: + if isinstance(obj, dict) and obj.get("unread_comments_count") is None: return obj.get("unread_comments_count", 0) - user_id = obj["author_id"] + user_id = self.context_data.get("user_id", None) course_id = obj["course_id"] thread_key = obj["_id"] _, unread_count = get_read_states([obj], user_id, course_id).get( @@ -197,8 +203,32 @@ def get_children(self, obj: dict[str, Any]) -> Optional[Any]: Optional[Any]: The responses or children related to the thread, or None if not included. """ if self.with_responses: - # Implement when needed - print(obj) + sorting_order = ( + DESCENDING + if self.context_data.get("reverse_order", True) + else ASCENDING + ) + children = list( + Comment().list( + comment_thread_id=ObjectId(obj["_id"]), + depth=0, + parent_id=None, + sort=sorting_order, + ) + ) + children_data = prepare_comment_data_for_get_children(children) + serializer = CommentSerializer( + data=children_data, + many=True, + context={ + "recursive": self.context_data.get("recursive", False), + "sort": sorting_order, + }, + exclude_fields=["sk"], + ) + if not serializer.is_valid(raise_exception=True): + raise ValidationError(serializer.errors) + return serializer.data return [] def get_resp_total(self, obj: dict[str, Any]) -> int: @@ -212,8 +242,8 @@ def get_resp_total(self, obj: dict[str, Any]) -> int: int: The total number of responses, defaulting to 0 if not included. """ if self.with_responses: - # Implement when needed - print(obj) + children = self.get_children(obj) or [] + return len(children) return 0 def to_representation(self, instance: dict[str, Any]) -> dict[str, Any]: diff --git a/forum/urls.py b/forum/urls.py index 856f7187..2542af2c 100644 --- a/forum/urls.py +++ b/forum/urls.py @@ -14,6 +14,7 @@ ThreadSubscriptionAPIView, UserSubscriptionAPIView, ) +from forum.views.threads import CreateThreadAPIView, ThreadsAPIView, UserThreadsAPIView from forum.views.votes import CommentVoteView, ThreadVoteView api_patterns = [ @@ -52,17 +53,17 @@ CommentsAPIView.as_view(), name="comments-api", ), + path( + "threads//comments", + CreateThreadCommentAPIView.as_view(), + name="create-parent-comment-api", + ), # search threads API path( "search/threads", SearchThreadsView.as_view(), name="search-thread-api", ), - path( - "threads//comments", - CreateThreadCommentAPIView.as_view(), - name="create-parent-comment-api", - ), # subscription APIs path( "users//subscriptions", @@ -79,6 +80,22 @@ ThreadSubscriptionAPIView.as_view(), name="thread-subscriptions", ), + # threads API + path( + "course/threads", + CreateThreadAPIView.as_view(), + name="create-thread-api", + ), + path( + "threads", + UserThreadsAPIView.as_view(), + name="user-threads-api", + ), + path( + "threads/", + ThreadsAPIView.as_view(), + name="threads-api", + ), # Proxy view for various API endpoints path( "", diff --git a/forum/utils.py b/forum/utils.py index e52f1892..acca3e54 100644 --- a/forum/utils.py +++ b/forum/utils.py @@ -99,3 +99,23 @@ def get_handler_by_name(name: str) -> Signal: return map_signals[name] except KeyError as exc: raise KeyError(f"No signal found for the name: {name}") from exc + + +def prepare_comment_data_for_get_children( + children: list[dict[str, Any]] +) -> list[dict[str, Any]]: + """Prepare children data to be used in serializer.""" + children_data = [] + for child in children: + children_data.append( + { + **child, + "id": str(child.get("_id")), + "user_id": child.get("author_id"), + "thread_id": str(child.get("comment_thread_id")), + "username": child.get("author_username"), + "parent_id": str(child.get("parent_id")), + "type": str(child.get("_type", "")).lower(), + } + ) + return children_data diff --git a/forum/views/comments.py b/forum/views/comments.py index 9615ca4c..96459a1a 100644 --- a/forum/views/comments.py +++ b/forum/views/comments.py @@ -74,6 +74,9 @@ def prepare_comment_api_response( "parent_id": str(comment.get("parent_id")), "type": str(comment.get("_type", "")).lower(), } + if not exclude_fields: + exclude_fields = [] + exclude_fields.append("children") serializer = CommentSerializer( data=comment_data, exclude_fields=exclude_fields, @@ -114,7 +117,7 @@ def get(self, request: Request, comment_id: str) -> Response: ) data = prepare_comment_api_response( comment, - exclude_fields=["sk", "endorsement"], + exclude_fields=["sk"], ) return Response(data, status=status.HTTP_200_OK) diff --git a/forum/views/subscriptions.py b/forum/views/subscriptions.py index 0cafaf96..736cad35 100644 --- a/forum/views/subscriptions.py +++ b/forum/views/subscriptions.py @@ -11,10 +11,10 @@ from forum.models import CommentThread, Subscriptions, Users from forum.models.model_utils import ( find_subscribed_threads, - get_group_ids_from_params, - handle_threads_query, + get_threads, subscribe_user, unsubscribe_user, + validate_params, ) from forum.pagination import ForumPagination from forum.serializers.subscriptions import SubscriptionSerializer @@ -125,79 +125,15 @@ def get(self, request: Request, user_id: str) -> Response: Raises: HTTP_400_BAD_REQUEST: If the user does not exist. """ - user = Users().get(user_id) - if not user: - return Response( - {"error": "User doesn't exist"}, - status=status.HTTP_400_BAD_REQUEST, - ) - params = request.GET.dict() - validations = self._validate_params(params) + validations = validate_params(params, user_id) if validations: return validations - threads = handle_threads_query( - find_subscribed_threads(user_id, params["course_id"]), - user_id, - params["course_id"], - get_group_ids_from_params(params), - params.get("author_id", ""), - params.get("thread_type"), - bool(params.get("flagged", False)), - bool(params.get("unread", False)), - bool(params.get("unanswered", False)), - bool(params.get("unresponded", False)), - bool(params.get("count_flagged", False)), - params.get("sort_key", ""), - int(params.get("page", 1)), - int(params.get("per_page", 100)), - ) - serializer = ThreadSerializer(threads.pop("collection"), many=True) - threads["collection"] = serializer.data + thread_ids = find_subscribed_threads(user_id, params["course_id"]) + threads = get_threads(params, user_id, ThreadSerializer, thread_ids) return Response(data=threads, status=status.HTTP_200_OK) - def _validate_params(self, params: dict[str, Any]) -> Response | None: - """ - Validate the request parameters. - - Args: - params (dict): The request parameters. - - Returns: - Response: A Response object with an error message if doesn't exist. - """ - - valid_params = [ - "course_id", - "author_id", - "thread_type", - "flagged", - "unread", - "unanswered", - "unresponded", - "count_flagged", - "sort_key", - "page", - "per_page", - "request_id", - ] - - for key in params: - if key not in valid_params: - return Response( - {"error": f"Invalid parameter: {key}"}, - status=status.HTTP_400_BAD_REQUEST, - ) - - if "course_id" not in params: - return Response( - {"error": "Missing required parameter: course_id"}, - status=status.HTTP_400_BAD_REQUEST, - ) - - return None - class ThreadSubscriptionAPIView(APIView): """ diff --git a/forum/views/threads.py b/forum/views/threads.py new file mode 100644 index 00000000..234581bc --- /dev/null +++ b/forum/views/threads.py @@ -0,0 +1,343 @@ +"""Forum Threads API Views.""" + +import logging +from typing import Any, Optional + +from django.core.exceptions import ObjectDoesNotExist +from rest_framework import status +from rest_framework.permissions import AllowAny +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.serializers import ValidationError +from rest_framework.views import APIView + +from forum.models.comments import Comment +from forum.models.model_utils import ( + delete_comments_of_a_thread, + get_threads, + validate_object, + validate_params, +) +from forum.models.threads import CommentThread +from forum.serializers.thread import ThreadSerializer +from forum.utils import get_int_value_from_collection, str_to_bool + +log = logging.getLogger(__name__) + + +def get_thread_data(thread: dict[str, Any]) -> dict[str, Any]: + """Prepare thread data for the api response.""" + _type = str(thread.get("_type", "")).lower() + thread_data = { + **thread, + "id": str(thread.get("_id")), + "type": "thread" if _type == "commentthread" else _type, + "user_id": thread.get("author_id"), + "username": str(thread.get("author_username")), + "comments_count": thread["comment_count"], + } + return thread_data + + +def prepare_thread_api_response( + thread: dict[str, Any], + include_context: Optional[bool] = False, + data_or_params: Optional[dict[str, Any]] = None, + include_data_from_params: Optional[bool] = False, +) -> dict[str, Any] | None: + """Serialize thread data for the api response.""" + thread_data = get_thread_data(thread) + + context = {} + if include_context: + context = { + "include_endorsed": True, + "include_read_state": True, + } + if data_or_params: + if include_data_from_params: + thread_data["resp_skip"] = get_int_value_from_collection( + data_or_params, "resp_skip", 0 + ) + thread_data["resp_limit"] = get_int_value_from_collection( + data_or_params, "resp_limit", 100 + ) + context["recursive"] = str_to_bool( + data_or_params.get("recursive", "False") + ) + context["with_responses"] = str_to_bool( + data_or_params.get("with_responses", "True") + ) + context["mark_as_read"] = str_to_bool( + data_or_params.get("mark_as_read", "False") + ) + context["reverse_order"] = str_to_bool( + data_or_params.get("reverse_order", "True") + ) + context["merge_question_type_responses"] = str_to_bool( + data_or_params.get("merge_question_type_responses", "False") + ) + + if user_id := data_or_params.get("user_id"): + context["user_id"] = user_id + + serializer = ThreadSerializer( + data=thread_data, + context=context, + ) + if not serializer.is_valid(raise_exception=True): + log.error(f"validation error in thread API call: {serializer.errors}") + raise ValidationError(serializer.errors) + + return serializer.data + + +class ThreadsAPIView(APIView): + """ + API view to handle operations related to threads. + + This view uses the CommentThread model for database interactions and the ThreadSerializer + for serializing and deserializing data. + """ + + permission_classes = (AllowAny,) + + def get(self, request: Request, thread_id: str) -> Response: + """ + Retrieve a thread by its ID. + + Args: + request: The HTTP request object. + thread_id: The ID of the thread to retrieve. + + Returns: + Response: A Response object containing the serialized thread data or an error message. + """ + try: + thread = validate_object(CommentThread, thread_id) + except ObjectDoesNotExist: + return Response( + {"error": "Thread does not exist"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + params = request.query_params + try: + serialized_data = prepare_thread_api_response(thread, True, params, True) + return Response(serialized_data) + + except ValidationError as error: + return Response( + error.detail, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + def delete(self, request: Request, thread_id: str) -> Response: + """ + Deletes a thread by it's ID. + + Parameters: + request (Request): The incoming request. + thread_id: The ID of the thread to be deleted. + Body: + Empty. + Response: + The details of the thread that is deleted. + """ + try: + thread = validate_object(CommentThread, thread_id) + except ObjectDoesNotExist: + return Response( + {"error": "thread does not exist"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + delete_comments_of_a_thread(thread_id) + thread = validate_object(CommentThread, thread_id) + + try: + serialized_data = prepare_thread_api_response(thread) + except ValidationError as error: + return Response( + error.detail, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + CommentThread().delete(thread_id) + + return Response(serialized_data, status=status.HTTP_200_OK) + + def put(self, request: Request, thread_id: str) -> Response: + """ + Updates an existing thread. + + Parameters: + request (Request): The incoming request. + thread_id: The ID of the thread to be edited. + Body: + fields to be updated. + Response: + The details of the thread that is updated. + """ + try: + thread = validate_object(Comment, thread_id) + except ObjectDoesNotExist: + return Response( + {"error": "thread does not exist"}, + status=status.HTTP_400_BAD_REQUEST, + ) + data = request.data + update_thread_data: dict[str, Any] = self._get_update_thread_data(data) + if thread: + update_thread_data["edit_history"] = thread.get("edit_history", []) + update_thread_data["original_body"] = thread.get("body") + + CommentThread().update(thread_id, **update_thread_data) + updated_thread = CommentThread().get(thread_id) + try: + if updated_thread: + serialized_data = prepare_thread_api_response( + updated_thread, + True, + data, + ) + return Response(serialized_data, status=status.HTTP_200_OK) + except ValidationError as error: + return Response( + error.detail, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + return Response( + {"error": "Thread is not updated"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + def _get_update_thread_data(self, data: dict[str, Any]) -> dict[str, Any]: + """convert request data to a dict excluding empty data""" + fields = [ + ("title", data.get("title")), + ("body", data.get("body")), + ("course_id", data.get("course_id")), + ("anonymous", str_to_bool(data.get("anonymous", "False"))), + ( + "anonymous_to_peers", + str_to_bool(data.get("anonymous_to_peers", "False")), + ), + ("closed", str_to_bool(data.get("closed", "False"))), + ("commentable_id", data.get("commentable_id", "course")), + ("author_id", data.get("user_id")), + ("editing_user_id", data.get("editing_user_id")), + ("pinned", str_to_bool(data.get("pinned", "False"))), + ("thread_type", data.get("thread_type", "discussion")), + ("edit_reason_code", data.get("edit_reason_code")), + ] + return {field: value for field, value in fields if value is not None} + + +class CreateThreadAPIView(APIView): + """ + API view to create a new thread. + + This view uses the CommentThread model for database interactions and the ThreadSerializer + for serializing and deserializing data. + """ + + permission_classes = (AllowAny,) + + def post(self, request: Request) -> Response: + """ + Create a new thread. + + Parameters: + request (Request): The incoming request. + Body: + fields to be added in a new thread. + Response: + The details of the thread that is created. + """ + data = request.data + try: + self.validate_request_data(data) + except ValueError as error: + return Response( + {"error": str(error)}, + status=status.HTTP_400_BAD_REQUEST, + ) + + thread = self.create_thread(data) + try: + serialized_data = prepare_thread_api_response(thread, True, data) + except ValidationError as error: + return Response( + error.detail, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + return Response(serialized_data, status=status.HTTP_200_OK) + + def validate_request_data(self, data: dict[str, Any]) -> None: + """ + Validates the request data if it exists or not. + + Parameters: + data: request data to validate. + Response: + raise exception if some data does not exists. + """ + fields_to_validate = ["title", "body", "course_id", "user_id"] + for field in fields_to_validate: + if field not in data or not data[field]: + raise ValueError(f"{field} is missing.") + + def create_thread(self, data: dict[str, Any]) -> Any: + """handle thread creation and returns a thread.""" + new_comment_id = CommentThread().insert( + title=data["title"], + body=data["body"], + course_id=data["course_id"], + anonymous=str_to_bool(data.get("anonymous", "False")), + anonymous_to_peers=str_to_bool(data.get("anonymous_to_peers", "False")), + author_id=data["user_id"], + commentable_id=data.get("commentable_id", "course"), + thread_type=data.get("thread_type", "discussion"), + ) + return CommentThread().get(new_comment_id) + + +class UserThreadsAPIView(APIView): + """ + API View for getting all threads of a course. + + This view provides an endpoint for retrieving all threads based on course id. + """ + + permission_classes = (AllowAny,) + + def get(self, request: Request) -> Response: + """ + Retrieve a course's threads. + + Args: + request (HttpRequest): The HTTP request object. + + Returns: + Response: A Response object with the threads data. + + Raises: + HTTP_400_BAD_REQUEST: If the user does not exist. + """ + params = request.GET.dict() + validations = validate_params(params) + if validations: + return validations + + user_id = params.get("user_id", "") + course_id = params.get("course_id") + thread_filter = { + "_type": {"$in": [CommentThread.content_type]}, + "course_id": {"$in": [course_id]}, + } + filtered_threads = CommentThread().find(thread_filter) + thread_ids = [thread["_id"] for thread in filtered_threads] + threads = get_threads(params, user_id, ThreadSerializer, thread_ids, True) + return Response(data=threads, status=status.HTTP_200_OK)