Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions forum/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
create_child_comment,
create_parent_comment,
delete_comment,
get_course_id_by_comment,
get_parent_comment,
update_comment,
)
Expand All @@ -25,6 +26,7 @@
from .threads import (
create_thread,
delete_thread,
get_course_id_by_thread,
get_thread,
get_user_threads,
update_thread,
Expand Down Expand Up @@ -59,6 +61,8 @@
"delete_thread",
"delete_thread_vote",
"get_commentables_stats",
"get_course_id_by_comment",
"get_course_id_by_thread",
"get_parent_comment",
"get_thread",
"get_thread_subscriptions",
Expand Down
14 changes: 14 additions & 0 deletions forum/api/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from forum.backends.mongodb.api import (
create_comment,
delete_comment_by_id,
get_course_id_by_comment_id,
get_thread_by_id,
get_thread_id_by_comment_id,
get_user_by_id,
Expand All @@ -21,6 +22,7 @@
)
from forum.backends.mongodb.comments import Comment
from forum.backends.mongodb.threads import CommentThread
from forum.backends.mongodb import api
from forum.serializers.comment import CommentSerializer
from forum.utils import ForumV2RequestError

Expand Down Expand Up @@ -294,3 +296,15 @@ def create_parent_comment(
)
except ValidationError as error:
raise error


def get_course_id_by_comment(comment_id: str) -> str | None:
"""
Return course_id for the matching comment.
It searches for comment_id both in mongodb and mysql.
"""
return (
get_course_id_by_comment_id(comment_id)
or api.get_course_id_by_comment_id(comment_id)
or None
)
33 changes: 16 additions & 17 deletions forum/api/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def _get_thread_ids_from_indexes(
context: str,
group_ids: list[int],
text: str,
commentable_id: Optional[str] = None,
commentable_ids: Optional[str] = None,
commentable_ids: Optional[list[str]] = None,
course_id: Optional[str] = None,
) -> tuple[list[str], Optional[str]]:
"""
Expand All @@ -39,7 +38,6 @@ def _get_thread_ids_from_indexes(
context,
group_ids,
text,
commentable_id=commentable_id,
commentable_ids=commentable_ids,
course_id=course_id,
)
Expand All @@ -50,7 +48,6 @@ def _get_thread_ids_from_indexes(
context,
group_ids,
corrected_text,
commentable_id=commentable_id,
commentable_ids=commentable_ids,
course_id=course_id,
)
Expand All @@ -62,28 +59,30 @@ def _get_thread_ids_from_indexes(

def search_threads(
text: str,
sort_key: str,
context: str,
user_id: str,
course_id: str,
group_ids: list[int],
author_id: str,
thread_type: str,
flagged: bool,
unread: bool,
unanswered: bool,
unresponded: bool,
count_flagged: bool,
commentable_id: str,
commentable_ids: str,
group_ids: Optional[list[int]] = None,
commentable_ids: Optional[list[str]] = None,
author_id: Optional[str] = None,
thread_type: Optional[str] = None,
sort_key: str = "date",
context: str = "course",
flagged: bool = False,
unread: bool = False,
unanswered: bool = False,
unresponded: bool = False,
count_flagged: bool = False,
page: int = FORUM_DEFAULT_PAGE,
per_page: int = FORUM_DEFAULT_PER_PAGE,
) -> dict[str, Any]:
"""
Search for threads based on the provided data.
"""
group_ids = group_ids or []
commentable_ids = commentable_ids or []

thread_ids, corrected_text = _get_thread_ids_from_indexes(
context, group_ids, text, commentable_id, commentable_ids, course_id
context, group_ids, text, commentable_ids, course_id
)

data = handle_threads_query(
Expand Down
18 changes: 18 additions & 0 deletions forum/api/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from forum.backends.mongodb.api import (
delete_comments_of_a_thread,
delete_subscriptions_of_a_thread,
get_course_id_by_thread_id,
get_threads,
)
from forum.backends.mongodb.api import mark_as_read as mark_thread_as_read
Expand All @@ -21,6 +22,7 @@
)
from forum.backends.mongodb.threads import CommentThread
from forum.backends.mongodb.users import Users
from forum.backends.mysql import api
from forum.serializers.thread import ThreadSerializer
from forum.utils import ForumV2RequestError, get_int_value_from_collection, str_to_bool

Expand All @@ -42,6 +44,7 @@ def _get_thread_data_from_request_data(data: dict[str, Any]) -> dict[str, Any]:
"close_reason_code",
"endorsed",
"pinned",
"group_id",
]
result = {field: data.get(field) for field in fields if data.get(field) is not None}

Expand Down Expand Up @@ -282,6 +285,7 @@ def create_thread(
anonymous_to_peers: bool = False,
commentable_id: str = "course",
thread_type: str = "discussion",
group_id: Optional[int] = None,
) -> dict[str, Any]:
"""
Create a new thread.
Expand All @@ -295,6 +299,7 @@ def create_thread(
closed: Whether the thread is closed.
commentable_id: The ID of the commentable.
user_id: The ID of the user.
group_id: The ID of the group.
Response:
The details of the thread that is created.
"""
Expand All @@ -307,6 +312,7 @@ def create_thread(
"anonymous_to_peers": anonymous_to_peers,
"commentable_id": commentable_id,
"thread_type": thread_type,
"group_id": group_id,
}
thread_data: dict[str, Any] = _get_thread_data_from_request_data(data)

Expand Down Expand Up @@ -376,3 +382,15 @@ def get_user_threads(
threads = get_threads(params, ThreadSerializer, thread_ids, user_id or "")

return threads


def get_course_id_by_thread(thread_id: str) -> str | None:
"""
Return course_id for the matching thread.
It searches for thread_id both in mongodb and mysql.
"""
return (
get_course_id_by_thread_id(thread_id)
or api.get_course_id_by_thread_id(thread_id)
or None
)
21 changes: 21 additions & 0 deletions forum/backends/mongodb/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,7 @@ def user_to_hash(
hash_data = {}
hash_data["username"] = user["username"]
hash_data["external_id"] = user["external_id"]
hash_data["id"] = user["external_id"]

comment_model = Comment()
thread_model = CommentThread()
Expand Down Expand Up @@ -1410,3 +1411,23 @@ def get_thread_id_by_comment_id(parent_comment_id: str) -> str:
if parent_comment:
return parent_comment["comment_thread_id"]
raise ValueError("Comment doesn't have the thread.")


def get_course_id_by_thread_id(thread_id: str) -> str | None:
"""
Return course_id for the matching thread.
"""
thread = CommentThread().get(thread_id)
if thread:
return thread.get("course_id")
return None


def get_course_id_by_comment_id(comment_id: str) -> str | None:
"""
Return course_id for the matching comment.
"""
comment = Comment().get(comment_id)
if comment:
return comment.get("course_id")
return None
20 changes: 20 additions & 0 deletions forum/backends/mysql/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,3 +1113,23 @@ def find_or_create_user(user_id: str) -> str:
"""Find or create user."""
user, _ = ForumUser.objects.get_or_create(user__pk=user_id)
return user.pk


def get_course_id_by_thread_id(thread_id: str) -> str | None:
"""
Return course_id for the matching thread.
"""
thread = CommentThread.objects.filter(id=thread_id).first()
if thread:
return thread.course_id
return None


def get_course_id_by_comment_id(comment_id: str) -> str | None:
"""
Return course_id for the matching comment.
"""
comment = Comment.objects.filter(id=comment_id).first()
if comment:
return comment.course_id
return None
21 changes: 9 additions & 12 deletions forum/search/comment_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,19 @@ class ThreadSearch(CommentSearch):
def build_must_clause(
self,
search_text: str,
commentable_id: Optional[str] = None,
commentable_ids: Optional[str] = None,
commentable_ids: Optional[list[str]] = None,
course_id: Optional[str] = None,
) -> list[dict[str, Any]]:
"""
Build the 'must' clause for thread-specific Elasticsearch queries based on input parameters.
"""
must: list[dict[str, Any]] = []
commentable_ids = commentable_ids or []

if commentable_id:
must.append({"term": {"commentable_id": commentable_id}})
if commentable_ids:
must.append({"terms": {"commentable_id": commentable_ids.split(",")}})
if len(commentable_ids) == 1:
must.append({"term": {"commentable_id": commentable_ids[0]}})
elif len(commentable_ids) > 1:
must.append({"terms": {"commentable_id": commentable_ids}})
if course_id:
must.append({"term": {"course_id": course_id}})

Expand Down Expand Up @@ -162,15 +162,14 @@ def get_thread_ids(
group_ids: list[int],
search_text: str,
sort_criteria: Optional[list[dict[str, str]]] = None,
commentable_id: Optional[str] = None,
commentable_ids: Optional[str] = None,
commentable_ids: Optional[list[str]] = None,
course_id: Optional[str] = None,
) -> list[str]:
"""
Retrieve thread IDs based on search criteria.
"""
must_clause: list[dict[str, Any]] = self.build_must_clause(
search_text, commentable_id, commentable_ids, course_id
search_text, commentable_ids, course_id
)
filter_clause: list[dict[str, Any]] = self.build_filter_clause(
context, group_ids
Expand Down Expand Up @@ -200,8 +199,7 @@ def get_thread_ids_with_corrected_text(
group_ids: list[int],
search_text: str,
sort_criteria: Optional[list[dict[str, str]]] = None,
commentable_id: Optional[str] = None,
commentable_ids: Optional[str] = None,
commentable_ids: Optional[list[str]] = None,
course_id: Optional[str] = None,
) -> list[str]:
"""
Expand All @@ -214,7 +212,6 @@ def get_thread_ids_with_corrected_text(
group_ids,
search_text,
sort_criteria,
commentable_id,
commentable_ids,
course_id,
)
4 changes: 4 additions & 0 deletions forum/serializers/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.count_flagged = self.context_data.pop("count_flagged", False)
self.include_endorsed = self.context_data.pop("include_endorsed", False)
self.include_read_state = self.context_data.pop("include_read_state", False)
self.merge_question_type_responses = self.context_data.pop(
"merge_question_type_responses", False
)

# Customize fields based on context
if not self.with_responses:
Expand Down Expand Up @@ -270,6 +273,7 @@ def to_representation(self, instance: dict[str, Any]) -> dict[str, Any]:
or self.context_data.get("recursive") is True
)
and data.get("thread_type") == "question"
and not self.merge_question_type_responses
):
children = data.pop("children")
data["non_endorsed_responses"] = []
Expand Down
29 changes: 29 additions & 0 deletions forum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,35 @@ def get_group_ids_from_params(params: dict[str, Any]) -> list[int]:
return group_ids


def get_commentable_ids_from_params(params: dict[str, Any]) -> list[str]:
"""
Extract commentable IDs from the provided parameters.

Args:
params (dict): A dictionary containing the parameters.

Returns:
list: A list of commentable IDs.

Raises:
ValueError: If both `commentable_id` and `commentable_ids` are specified in the parameters.
"""
if "commentable_id" in params and "commentable_ids" in params:
raise ValueError("Cannot specify both commentable_id and commentable_ids")

commentable_id = params.get("commentable_id")
if commentable_id:
return [commentable_id]

commentable_ids = params.get("commentable_ids", [])
if isinstance(commentable_ids, str):
return commentable_ids.split(",")
elif isinstance(commentable_ids, list):
return commentable_ids

return []


def get_sort_criteria(sort_key: str) -> Sequence[tuple[str, int]]:
"""
Generate sorting criteria based on the provided key.
Expand Down
5 changes: 2 additions & 3 deletions forum/views/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from forum.api.search import search_threads
from forum.constants import FORUM_DEFAULT_PAGE, FORUM_DEFAULT_PER_PAGE
from forum.utils import get_group_ids_from_params
from forum.utils import get_commentable_ids_from_params, get_group_ids_from_params


class SearchThreadsView(APIView):
Expand Down Expand Up @@ -74,8 +74,7 @@ def _validate_and_extract_params(self, request: Request) -> dict[str, Any]:
# Group IDs extraction
params["group_ids"] = get_group_ids_from_params(data)

params["commentable_id"] = data.get("commentable_id")
params["commentable_ids"] = data.get("commentable_ids")
params["commentable_ids"] = get_commentable_ids_from_params(data)

return params

Expand Down