Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion forum/models/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,14 @@ def insert(
if parent_id:
comment_data["parent_id"] = ObjectId(parent_id)

comment_data["endorsement"] = None

result = self._collection.insert_one(comment_data)
if parent_id:
self.update_child_count_in_parent_comment(parent_id, 1)
if comment_thread_id:
self.update_comment_count_in_comment_thread(comment_thread_id, 1)
self.update_sk(str(result.inserted_id), parent_id)
if result:
get_handler_by_name("comment_inserted").send(
sender=self.__class__, comment_id=str(result.inserted_id)
Expand Down Expand Up @@ -163,6 +166,7 @@ def update(
editing_user_id: Optional[str] = None,
edit_reason_code: Optional[str] = None,
endorsement_user_id: Optional[str] = None,
sk: Optional[str] = None,
) -> int:
"""
Updates a comment document in the database.
Expand Down Expand Up @@ -205,6 +209,7 @@ def update(
("child_count", child_count),
("depth", depth),
("closed", closed),
("sk", sk),
]
update_data: dict[str, Any] = {
field: value for field, value in fields if value is not None
Expand All @@ -221,6 +226,7 @@ def update(
edit_history = [] if edit_history is None else edit_history
edit_history.append(
{
"author_id": editing_user_id,
"original_body": original_body,
"reason_code": edit_reason_code,
"editor_username": self.get_author_username(editing_user_id),
Expand Down Expand Up @@ -330,5 +336,19 @@ def update_comment_count_in_comment_thread(
Returns:
None.
"""
update_comment_count_query = {"$inc": {"comment_count": count}}
update_comment_count_query = {
"$inc": {"comment_count": count},
"$set": {"last_activity_at": datetime.now()},
}
CommentThread().update_count(comment_thread_id, update_comment_count_query)

def get_sk(self, _id: str, parent_id: Optional[str]) -> str:
"""Returns sk field."""
if parent_id is not None:
return f"{parent_id}-{_id}"
return f"{_id}"

def update_sk(self, _id: str, parent_id: Optional[str]) -> None:
"""Updates sk field."""
sk = self.get_sk(_id, parent_id)
self.update(_id, sk=sk)
6 changes: 5 additions & 1 deletion forum/models/contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def list(self, **kwargs: Any) -> Any:
"""
if self.content_type:
kwargs["_type"] = self.content_type
return self._collection.find(kwargs)
result = self._collection.find(kwargs)
sort = kwargs.pop("sort", None)
if sort:
return result.sort("sk", sort)
return result

@classmethod
def get_votes_dict(cls, up: List[str], down: List[str]) -> dict[str, Any]:
Expand Down
143 changes: 126 additions & 17 deletions forum/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from bson import ObjectId
from django.core.exceptions import ObjectDoesNotExist
from rest_framework import status
from rest_framework.response import Response

from forum.models import Comment, CommentThread, Contents, Subscriptions, Users

Expand Down Expand Up @@ -332,23 +334,23 @@ def get_read_states(
whether the thread is read and the unread comment count.
"""
read_states = {}
user = Users().find_one({"_id": user_id, "read_states.course_id": course_id})
read_state = user["read_states"][0] if user else {}

if read_state:
read_dates = read_state.get("last_read_times", {})
for thread in threads:
thread_key = str(thread["_id"])
if thread_key in read_dates:
is_read = read_dates[thread_key] >= thread["last_activity_at"]
unread_comment_count = Contents().count_documents(
{
"comment_thread_id": ObjectId(thread_key),
"created_at": {"$gte": read_dates[thread_key]},
"author_id": {"$ne": str(user_id)},
}
)
read_states[thread_key] = [is_read, unread_comment_count]
if user_id:
user = Users().find_one({"_id": user_id, "read_states.course_id": course_id})
read_state = user["read_states"][0] if user else {}
if read_state:
read_dates = read_state.get("last_read_times", {})
for thread in threads:
thread_key = str(thread["_id"])
if thread_key in read_dates:
is_read = read_dates[thread_key] >= thread["last_activity_at"]
unread_comment_count = Contents().count_documents(
{
"comment_thread_id": ObjectId(thread_key),
"created_at": {"$gte": read_dates[thread_key]},
"author_id": {"$ne": str(user_id)},
}
)
read_states[thread_key] = [is_read, unread_comment_count]

return read_states

Expand Down Expand Up @@ -776,3 +778,110 @@ def subscribe_user(
def unsubscribe_user(user_id: str, source_id: str) -> None:
"""Unsubscribe a user from a source."""
Subscriptions().delete_subscription(user_id, source_id)


def delete_comments_of_a_thread(thread_id: str) -> None:
"""Delete comments of a thread."""
for comment in Comment().list(
comment_thread_id=ObjectId(thread_id),
depth=0,
parent_id=None,
):
Comment().delete(comment["_id"])


def validate_params(
params: dict[str, Any], user_id: Optional[str] = None
) -> Response | None:
"""
Validate the request parameters.

Args:
params (dict): The request parameters.
user_id (optional[str]): The Id of the user for validation.

Returns:
Response: A Response object with an error message if doesn't exist.
"""
valid_params = [
"course_id",
"author_id",
"thread_type",
"flagged",
"unread",
"unanswered",
"unresponded",
"count_flagged",
"sort_key",
"page",
"per_page",
"request_id",
]
if not user_id:
valid_params.append("user_id")
if "user_id" not in params:
return Response(
{"error": "Missing required parameter: user_id"},
status=status.HTTP_400_BAD_REQUEST,
)
user_id = params.get("user_id")

for key in params:
if key not in valid_params:
return Response(
{"error": f"Invalid parameter: {key}"},
status=status.HTTP_400_BAD_REQUEST,
)

if "course_id" not in params:
return Response(
{"error": "Missing required parameter: course_id"},
status=status.HTTP_400_BAD_REQUEST,
)

if user_id:
user = Users().get(user_id)
if not user:
return Response(
{"error": "User doesn't exist"},
status=status.HTTP_400_BAD_REQUEST,
)

return None


def get_threads(
params: dict[str, Any],
user_id: str,
serializer: Any,
thread_ids: list[str],
include_context: Optional[bool] = False,
) -> dict[str, Any]:
"""get subscribed or all threads of a specific course for a specific user."""
threads = handle_threads_query(
thread_ids,
user_id,
params["course_id"],
get_group_ids_from_params(params),
params.get("author_id", ""),
params.get("thread_type"),
bool(params.get("flagged", False)),
bool(params.get("unread", False)),
bool(params.get("unanswered", False)),
bool(params.get("unresponded", False)),
bool(params.get("count_flagged", False)),
params.get("sort_key", ""),
int(params.get("page", 1)),
int(params.get("per_page", 100)),
)
context: dict[str, Any] = {}
if include_context:
context = {
"include_endorsed": True,
"include_read_state": True,
}
if user_id:
context["user_id"] = user_id
serializer = serializer(threads.pop("collection"), many=True, context=context)
threads["collection"] = serializer.data
return threads
26 changes: 24 additions & 2 deletions forum/models/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from bson import ObjectId

from forum.models.contents import BaseContents
from forum.models.users import Users
from forum.utils import get_handler_by_name


Expand Down Expand Up @@ -89,7 +90,7 @@ def insert(
course_id: str,
commentable_id: str,
author_id: str,
author_username: str,
author_username: Optional[str] = None,
anonymous: bool = False,
anonymous_to_peers: bool = False,
thread_type: str = "discussion",
Expand Down Expand Up @@ -150,7 +151,7 @@ def insert(
"anonymous_to_peers": anonymous_to_peers,
"closed": False,
"author_id": author_id,
"author_username": author_username,
"author_username": author_username or self.get_author_username(author_id),
"created_at": date,
"updated_at": date,
"last_activity_at": date,
Expand Down Expand Up @@ -188,6 +189,10 @@ def update(
pinned: Optional[bool] = None,
comments_count: Optional[int] = None,
endorsed: Optional[bool] = None,
edit_history: Optional[list[dict[str, Any]]] = None,
original_body: Optional[str] = None,
editing_user_id: Optional[str] = None,
edit_reason_code: Optional[str] = None,
) -> int:
"""
Updates a thread document in the database.
Expand Down Expand Up @@ -241,6 +246,18 @@ def update(
update_data: dict[str, Any] = {
field: value for field, value in fields if value is not None
}
if editing_user_id:
edit_history = [] if edit_history is None else edit_history
edit_history.append(
{
"author_id": editing_user_id,
"original_body": original_body,
"reason_code": edit_reason_code,
"editor_username": self.get_author_username(editing_user_id),
"created_at": datetime.now(),
}
)
update_data["edit_history"] = edit_history

date = datetime.now()
update_data["updated_at"] = date
Expand All @@ -254,3 +271,8 @@ def update(
sender=self.__class__, comment_thread_id=thread_id
)
return result.modified_count

def get_author_username(self, author_id: str) -> str | None:
"""Return username for the respective author_id(user_id)"""
user = Users().get(author_id)
return user.get("username") if user else None
39 changes: 27 additions & 12 deletions forum/serializers/comment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

from typing import Any

from bson import ObjectId
from rest_framework import serializers

from forum.models import Comment
from forum.serializers.contents import ContentSerializer
from forum.serializers.custom_datetime import CustomDateTimeField
from forum.utils import prepare_comment_data_for_get_children


class EndorsementSerializer(serializers.Serializer[dict[str, Any]]):
Expand Down Expand Up @@ -49,11 +52,12 @@ class CommentSerializer(ContentSerializer):

endorsed = serializers.BooleanField(default=False)
depth = serializers.IntegerField(default=0)
thread_id = serializers.CharField()
thread_id = serializers.CharField(source="comment_thread_id")
parent_id = serializers.CharField(default=None, allow_null=True)
child_count = serializers.IntegerField(default=0)
sk = serializers.SerializerMethodField()
sk = serializers.CharField(default=None, required=False, allow_null=True)
endorsement = EndorsementSerializer(default=None, required=False, allow_null=True)
children = serializers.SerializerMethodField()

def __init__(self, *args: Any, **kwargs: Any) -> None:
exclude_fields = kwargs.pop("exclude_fields", None)
Expand All @@ -62,22 +66,33 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
for field in exclude_fields:
self.fields.pop(field, None)

def get_children(self, obj: Any) -> list[dict[str, Any]]:
"""Get comments of a thread."""
if not self.context.get("recursive", False):
return []

children = list(
Comment().list(
parent_id=ObjectId(obj["_id"]),
depth=1,
sort=self.context.get("sort", -1),
)
)
children_data = prepare_comment_data_for_get_children(children)
serializer = CommentSerializer(
children_data,
many=True,
context={"recursive": False},
exclude_fields=["sk"],
)
return list(serializer.data)

def to_representation(self, instance: Any) -> dict[str, Any]:
comment = super().to_representation(instance)
if comment["parent_id"] == "None":
comment["parent_id"] = None
return comment

def get_sk(self, obj: dict[str, Any]) -> str:
"""Return sk field"""
is_child = obj.get("parent_id")
if is_child is not None:
return "{parent_id}-{id}".format(
parent_id=obj.get("parent_id"), id=obj.get("_id")
)
else:
return "{id}".format(id=obj.get("_id"))

def create(self, validated_data: dict[str, Any]) -> Any:
"""Raise NotImplementedError"""
raise NotImplementedError
Expand Down
Loading