diff --git a/forum/api/__init__.py b/forum/api/__init__.py index 55a7c4a3..f011f620 100644 --- a/forum/api/__init__.py +++ b/forum/api/__init__.py @@ -7,6 +7,7 @@ create_child_comment, create_parent_comment, delete_comment, + get_course_id_by_comment, get_parent_comment, update_comment, ) @@ -25,6 +26,7 @@ from .threads import ( create_thread, delete_thread, + get_course_id_by_thread, get_thread, get_user_threads, update_thread, @@ -59,6 +61,8 @@ "delete_thread", "delete_thread_vote", "get_commentables_stats", + "get_course_id_by_comment", + "get_course_id_by_thread", "get_parent_comment", "get_thread", "get_thread_subscriptions", diff --git a/forum/api/comments.py b/forum/api/comments.py index 63344382..e2440577 100644 --- a/forum/api/comments.py +++ b/forum/api/comments.py @@ -11,6 +11,7 @@ from forum.backends.mongodb.api import ( create_comment, delete_comment_by_id, + get_course_id_by_comment_id, get_thread_by_id, get_thread_id_by_comment_id, get_user_by_id, @@ -21,6 +22,7 @@ ) from forum.backends.mongodb.comments import Comment from forum.backends.mongodb.threads import CommentThread +from forum.backends.mongodb import api from forum.serializers.comment import CommentSerializer from forum.utils import ForumV2RequestError @@ -294,3 +296,15 @@ def create_parent_comment( ) except ValidationError as error: raise error + + +def get_course_id_by_comment(comment_id: str) -> str | None: + """ + Return course_id for the matching comment. + It searches for comment_id both in mongodb and mysql. + """ + return ( + get_course_id_by_comment_id(comment_id) + or api.get_course_id_by_comment_id(comment_id) + or None + ) diff --git a/forum/api/search.py b/forum/api/search.py index e09d887b..789dbd92 100644 --- a/forum/api/search.py +++ b/forum/api/search.py @@ -14,8 +14,7 @@ def _get_thread_ids_from_indexes( context: str, group_ids: list[int], text: str, - commentable_id: Optional[str] = None, - commentable_ids: Optional[str] = None, + commentable_ids: Optional[list[str]] = None, course_id: Optional[str] = None, ) -> tuple[list[str], Optional[str]]: """ @@ -39,7 +38,6 @@ def _get_thread_ids_from_indexes( context, group_ids, text, - commentable_id=commentable_id, commentable_ids=commentable_ids, course_id=course_id, ) @@ -50,7 +48,6 @@ def _get_thread_ids_from_indexes( context, group_ids, corrected_text, - commentable_id=commentable_id, commentable_ids=commentable_ids, course_id=course_id, ) @@ -62,28 +59,30 @@ def _get_thread_ids_from_indexes( def search_threads( text: str, - sort_key: str, - context: str, user_id: str, course_id: str, - group_ids: list[int], - author_id: str, - thread_type: str, - flagged: bool, - unread: bool, - unanswered: bool, - unresponded: bool, - count_flagged: bool, - commentable_id: str, - commentable_ids: str, + group_ids: Optional[list[int]] = None, + commentable_ids: Optional[list[str]] = None, + author_id: Optional[str] = None, + thread_type: Optional[str] = None, + sort_key: str = "date", + context: str = "course", + flagged: bool = False, + unread: bool = False, + unanswered: bool = False, + unresponded: bool = False, + count_flagged: bool = False, page: int = FORUM_DEFAULT_PAGE, per_page: int = FORUM_DEFAULT_PER_PAGE, ) -> dict[str, Any]: """ Search for threads based on the provided data. """ + group_ids = group_ids or [] + commentable_ids = commentable_ids or [] + thread_ids, corrected_text = _get_thread_ids_from_indexes( - context, group_ids, text, commentable_id, commentable_ids, course_id + context, group_ids, text, commentable_ids, course_id ) data = handle_threads_query( diff --git a/forum/api/threads.py b/forum/api/threads.py index 4a1fd680..cf070913 100644 --- a/forum/api/threads.py +++ b/forum/api/threads.py @@ -11,6 +11,7 @@ from forum.backends.mongodb.api import ( delete_comments_of_a_thread, delete_subscriptions_of_a_thread, + get_course_id_by_thread_id, get_threads, ) from forum.backends.mongodb.api import mark_as_read as mark_thread_as_read @@ -21,6 +22,7 @@ ) from forum.backends.mongodb.threads import CommentThread from forum.backends.mongodb.users import Users +from forum.backends.mysql import api from forum.serializers.thread import ThreadSerializer from forum.utils import ForumV2RequestError, get_int_value_from_collection, str_to_bool @@ -42,6 +44,7 @@ def _get_thread_data_from_request_data(data: dict[str, Any]) -> dict[str, Any]: "close_reason_code", "endorsed", "pinned", + "group_id", ] result = {field: data.get(field) for field in fields if data.get(field) is not None} @@ -282,6 +285,7 @@ def create_thread( anonymous_to_peers: bool = False, commentable_id: str = "course", thread_type: str = "discussion", + group_id: Optional[int] = None, ) -> dict[str, Any]: """ Create a new thread. @@ -295,6 +299,7 @@ def create_thread( closed: Whether the thread is closed. commentable_id: The ID of the commentable. user_id: The ID of the user. + group_id: The ID of the group. Response: The details of the thread that is created. """ @@ -307,6 +312,7 @@ def create_thread( "anonymous_to_peers": anonymous_to_peers, "commentable_id": commentable_id, "thread_type": thread_type, + "group_id": group_id, } thread_data: dict[str, Any] = _get_thread_data_from_request_data(data) @@ -376,3 +382,15 @@ def get_user_threads( threads = get_threads(params, ThreadSerializer, thread_ids, user_id or "") return threads + + +def get_course_id_by_thread(thread_id: str) -> str | None: + """ + Return course_id for the matching thread. + It searches for thread_id both in mongodb and mysql. + """ + return ( + get_course_id_by_thread_id(thread_id) + or api.get_course_id_by_thread_id(thread_id) + or None + ) diff --git a/forum/backends/mongodb/api.py b/forum/backends/mongodb/api.py index 90d6b7c2..a978b454 100644 --- a/forum/backends/mongodb/api.py +++ b/forum/backends/mongodb/api.py @@ -975,6 +975,7 @@ def user_to_hash( hash_data = {} hash_data["username"] = user["username"] hash_data["external_id"] = user["external_id"] + hash_data["id"] = user["external_id"] comment_model = Comment() thread_model = CommentThread() @@ -1410,3 +1411,23 @@ def get_thread_id_by_comment_id(parent_comment_id: str) -> str: if parent_comment: return parent_comment["comment_thread_id"] raise ValueError("Comment doesn't have the thread.") + + +def get_course_id_by_thread_id(thread_id: str) -> str | None: + """ + Return course_id for the matching thread. + """ + thread = CommentThread().get(thread_id) + if thread: + return thread.get("course_id") + return None + + +def get_course_id_by_comment_id(comment_id: str) -> str | None: + """ + Return course_id for the matching comment. + """ + comment = Comment().get(comment_id) + if comment: + return comment.get("course_id") + return None diff --git a/forum/backends/mysql/api.py b/forum/backends/mysql/api.py index 3f7e52a8..dd52037e 100644 --- a/forum/backends/mysql/api.py +++ b/forum/backends/mysql/api.py @@ -1113,3 +1113,23 @@ def find_or_create_user(user_id: str) -> str: """Find or create user.""" user, _ = ForumUser.objects.get_or_create(user__pk=user_id) return user.pk + + +def get_course_id_by_thread_id(thread_id: str) -> str | None: + """ + Return course_id for the matching thread. + """ + thread = CommentThread.objects.filter(id=thread_id).first() + if thread: + return thread.course_id + return None + + +def get_course_id_by_comment_id(comment_id: str) -> str | None: + """ + Return course_id for the matching comment. + """ + comment = Comment.objects.filter(id=comment_id).first() + if comment: + return comment.course_id + return None diff --git a/forum/search/comment_search.py b/forum/search/comment_search.py index 1d9ff752..e4013580 100644 --- a/forum/search/comment_search.py +++ b/forum/search/comment_search.py @@ -93,19 +93,19 @@ class ThreadSearch(CommentSearch): def build_must_clause( self, search_text: str, - commentable_id: Optional[str] = None, - commentable_ids: Optional[str] = None, + commentable_ids: Optional[list[str]] = None, course_id: Optional[str] = None, ) -> list[dict[str, Any]]: """ Build the 'must' clause for thread-specific Elasticsearch queries based on input parameters. """ must: list[dict[str, Any]] = [] + commentable_ids = commentable_ids or [] - if commentable_id: - must.append({"term": {"commentable_id": commentable_id}}) - if commentable_ids: - must.append({"terms": {"commentable_id": commentable_ids.split(",")}}) + if len(commentable_ids) == 1: + must.append({"term": {"commentable_id": commentable_ids[0]}}) + elif len(commentable_ids) > 1: + must.append({"terms": {"commentable_id": commentable_ids}}) if course_id: must.append({"term": {"course_id": course_id}}) @@ -162,15 +162,14 @@ def get_thread_ids( group_ids: list[int], search_text: str, sort_criteria: Optional[list[dict[str, str]]] = None, - commentable_id: Optional[str] = None, - commentable_ids: Optional[str] = None, + commentable_ids: Optional[list[str]] = None, course_id: Optional[str] = None, ) -> list[str]: """ Retrieve thread IDs based on search criteria. """ must_clause: list[dict[str, Any]] = self.build_must_clause( - search_text, commentable_id, commentable_ids, course_id + search_text, commentable_ids, course_id ) filter_clause: list[dict[str, Any]] = self.build_filter_clause( context, group_ids @@ -200,8 +199,7 @@ def get_thread_ids_with_corrected_text( group_ids: list[int], search_text: str, sort_criteria: Optional[list[dict[str, str]]] = None, - commentable_id: Optional[str] = None, - commentable_ids: Optional[str] = None, + commentable_ids: Optional[list[str]] = None, course_id: Optional[str] = None, ) -> list[str]: """ @@ -214,7 +212,6 @@ def get_thread_ids_with_corrected_text( group_ids, search_text, sort_criteria, - commentable_id, commentable_ids, course_id, ) diff --git a/forum/serializers/thread.py b/forum/serializers/thread.py index 22e60a59..38db25d7 100644 --- a/forum/serializers/thread.py +++ b/forum/serializers/thread.py @@ -94,6 +94,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.count_flagged = self.context_data.pop("count_flagged", False) self.include_endorsed = self.context_data.pop("include_endorsed", False) self.include_read_state = self.context_data.pop("include_read_state", False) + self.merge_question_type_responses = self.context_data.pop( + "merge_question_type_responses", False + ) # Customize fields based on context if not self.with_responses: @@ -270,6 +273,7 @@ def to_representation(self, instance: dict[str, Any]) -> dict[str, Any]: or self.context_data.get("recursive") is True ) and data.get("thread_type") == "question" + and not self.merge_question_type_responses ): children = data.pop("children") data["non_endorsed_responses"] = [] diff --git a/forum/utils.py b/forum/utils.py index 0b613b74..98d05ca4 100644 --- a/forum/utils.py +++ b/forum/utils.py @@ -184,6 +184,35 @@ def get_group_ids_from_params(params: dict[str, Any]) -> list[int]: return group_ids +def get_commentable_ids_from_params(params: dict[str, Any]) -> list[str]: + """ + Extract commentable IDs from the provided parameters. + + Args: + params (dict): A dictionary containing the parameters. + + Returns: + list: A list of commentable IDs. + + Raises: + ValueError: If both `commentable_id` and `commentable_ids` are specified in the parameters. + """ + if "commentable_id" in params and "commentable_ids" in params: + raise ValueError("Cannot specify both commentable_id and commentable_ids") + + commentable_id = params.get("commentable_id") + if commentable_id: + return [commentable_id] + + commentable_ids = params.get("commentable_ids", []) + if isinstance(commentable_ids, str): + return commentable_ids.split(",") + elif isinstance(commentable_ids, list): + return commentable_ids + + return [] + + def get_sort_criteria(sort_key: str) -> Sequence[tuple[str, int]]: """ Generate sorting criteria based on the provided key. diff --git a/forum/views/search.py b/forum/views/search.py index 48f16b47..785ac6bf 100644 --- a/forum/views/search.py +++ b/forum/views/search.py @@ -12,7 +12,7 @@ from forum.api.search import search_threads from forum.constants import FORUM_DEFAULT_PAGE, FORUM_DEFAULT_PER_PAGE -from forum.utils import get_group_ids_from_params +from forum.utils import get_commentable_ids_from_params, get_group_ids_from_params class SearchThreadsView(APIView): @@ -74,8 +74,7 @@ def _validate_and_extract_params(self, request: Request) -> dict[str, Any]: # Group IDs extraction params["group_ids"] = get_group_ids_from_params(data) - params["commentable_id"] = data.get("commentable_id") - params["commentable_ids"] = data.get("commentable_ids") + params["commentable_ids"] = get_commentable_ids_from_params(data) return params