Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions forum/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,39 @@
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any, Optional

from bson import ObjectId
from pymongo.collection import Collection
from pymongo.collection import Collection as PymongoCollection
from pymongo.cursor import Cursor

from forum.mongo import MongoBackend
from forum.mongo import Database, get_database

Collection = PymongoCollection[dict[str, Any]]


class MongoBaseModel(ABC):
"""Abstract Class for Mongo model implementation"""

def __init__(
self,
collection_name: Optional[str] = None,
client: Optional[MongoBackend] = None,
) -> None:
self.client: MongoBackend = client or MongoBackend(collection=collection_name)
MONGODB_DATABASE: Optional[Database] = None
COLLECTION_NAME: str = "default"

@property
def _collection(self) -> Collection[Dict[str, Any]]:
"""Get mongo db collection"""
return self.get_client.collection
def _collection(self) -> Collection:
return self.__get_database()[self.COLLECTION_NAME]

@property
def get_client(self) -> MongoBackend:
"""Get mongo client"""
return self.client
@classmethod
def __get_database(cls) -> Database:
"""Get or create static class database."""
if cls.MONGODB_DATABASE is None:
cls.MONGODB_DATABASE = get_database()
return cls.MONGODB_DATABASE

def get(self, _id: str) -> Optional[Dict[str, Any]]:
def get(self, _id: str) -> Optional[dict[str, Any]]:
"""Get a document by filter"""
return self._collection.find_one({"_id": _id})

def list(self, **kwargs: Any) -> Cursor[Dict[str, Any]]:
def list(self, **kwargs: Any) -> Cursor[dict[str, Any]]:
"""Get a list of all documents filtered by kwargs"""
return self._collection.find(kwargs)

Expand Down
14 changes: 1 addition & 13 deletions forum/models/contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from bson import ObjectId

from forum.models.base_model import MongoBaseModel
from forum.mongo import MongoBackend


class Contents(MongoBaseModel):
Expand All @@ -17,18 +16,7 @@ class Contents(MongoBaseModel):
"""

content_type: str = ""

def __init__(
self, collection_name: str = "contents", client: Optional[MongoBackend] = None
) -> None:
"""
Initializes the Content class.

Args:
collection_name: The name of the MongoDB collection.
client: The MongoDB client.
"""
super().__init__(collection_name, client)
COLLECTION_NAME: str = "contents"

def get(
self, _id: str
Expand Down
14 changes: 1 addition & 13 deletions forum/models/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,14 @@
from typing import Any, Dict, List, Optional

from forum.models.base_model import MongoBaseModel
from forum.mongo import MongoBackend


class Users(MongoBaseModel):
"""
Users class for cs_comments_service user model
"""

def __init__(
self, collection_name: str = "users", client: Optional[MongoBackend] = None
) -> None:
"""
Initializes the Users class.

Args:
collection_name: The name of the MongoDB collection.
client: The MongoDB client.

"""
super().__init__(collection_name, client)
COLLECTION_NAME: str = "users"

def insert(
self,
Expand Down
89 changes: 38 additions & 51 deletions forum/mongo.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,47 @@
"""Mongo module for forum app."""

import logging
from typing import Any, Dict
from typing import Any, Optional

from django.conf import settings
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.database import Database as PymongoDatabase

log = logging.getLogger(__name__)


class MongoBackend:
"""Class for mongoDB cs_comments_service backend."""

def __init__(self, **kwargs: Any) -> None:
"""
Connect to MongoDB.

:Parameters:

- `host`: hostname
- `port`: port
- `user`: collection username
- `password`: collection user password
- `database`: name of the database
- `collection`: name of the collection
- `authsource`: name of the authentication database
- `extra`: parameters to pymongo.MongoClient not listed above

"""
# Extract connection parameters from kwargs

host = kwargs.get("host", settings.MONGO_HOST)
port = kwargs.get("port", settings.MONGO_PORT)

user = kwargs.get("user", "")
password = kwargs.get("password", "")

db_name = kwargs.get("database", "cs_comments_service")
collection_name = kwargs.get("collection", "")

auth_source = kwargs.get("authsource") or None

# Other mongo connection arguments
extra = kwargs.get("extra", {})

# Make timezone aware by default
extra["tz_aware"] = extra.get("tz_aware", True)

# Connect to database and get collection

self.connection: MongoClient[Any] = MongoClient(host=host, port=port, **extra)

database = self.connection[db_name]

if user or password:
database.authenticate(user, password, source=auth_source)

self.collection: Collection[Dict[str, Any]] = database[collection_name]
Database = PymongoDatabase[dict[str, Any]]


def get_database(
host: str = settings.MONGO_HOST,
port: int = settings.MONGO_PORT,
user: str = "",
password: str = "",
database: str = "cs_comments_service",
authsource: Optional[str] = None,
tz_aware: bool = True,
**extra: Any
) -> Database:
"""
Connect to MongoDB.

:Parameters:

- `host`: hostname
- `port`: port
- `user`: collection username
- `password`: collection user password
- `database`: name of the database
- `authsource`: name of the authentication database
- `extra`: parameters to pymongo.MongoClient not listed above

"""
connection: MongoClient[Any] = MongoClient(
host=host, port=port, tz_aware=tz_aware, **extra
)
db = connection[database]

if user or password:
db.authenticate(user, password, source=authsource)

return db
59 changes: 7 additions & 52 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,67 +5,22 @@

import json
from typing import Any, Generator, Union
from unittest.mock import MagicMock, patch

import mongomock
import pytest
from django.http.response import HttpResponse
from django.test import Client
from pymongo import MongoClient

from forum.models import Comment, CommentThread, Contents, Users
from forum.mongo import MongoBackend


@pytest.fixture(name="mock_mongo_backend")
def fixture_mock_mongo_backend() -> MagicMock:
"""Mock MongoClient for tests."""
@pytest.fixture(autouse=True)
def patch_default_mongo_database(monkeypatch: pytest.MonkeyPatch) -> None:
"""Mock default mongodb database for tests."""
client: MongoClient[Any] = mongomock.MongoClient()
db = client["test_forum_db"]

collections = {
"contents": db["contents"],
"users": db["users"],
}

mock_backend: MagicMock = MagicMock(spec=MongoBackend)
for name, collection in collections.items():
setattr(mock_backend, name, collection)

return mock_backend


@pytest.fixture(name="patch_mongo_backend")
def fixture_patch_mongo_backend(
mock_mongo_backend: MagicMock,
) -> Generator[MagicMock, None, None]:
"""Patch the MongoBackend instance with a mock."""
with patch("forum.mongo.MongoBackend", return_value=mock_mongo_backend):
yield mock_mongo_backend


@pytest.fixture(name="users_model")
def fixture_users_model(patch_mongo_backend: MagicMock) -> Users:
"""Get Users model with patched backend."""
return Users(client=patch_mongo_backend.users)


@pytest.fixture(name="comment_thread_model")
def fixture_comment_thread_model(patch_mongo_backend: MagicMock) -> CommentThread:
"""Get CommentThread model with patched backend."""
return CommentThread(client=patch_mongo_backend.contents)


@pytest.fixture(name="comment_model")
def fixture_comment_model(patch_mongo_backend: MagicMock) -> Comment:
"""Get Comment model with patched backend."""
return Comment(client=patch_mongo_backend.contents)


@pytest.fixture(name="content_model")
def fixture_content_model(patch_mongo_backend: MagicMock) -> Contents:
"""Get Contents model with patched backend."""
return Contents(client=patch_mongo_backend.contents)
monkeypatch.setattr(
"forum.models.base_model.MongoBaseModel.MONGODB_DATABASE",
client["test_forum_db"],
)


class APIClient(Client):
Expand Down
32 changes: 16 additions & 16 deletions tests/test_models/test_comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,72 +5,72 @@
from forum.models import Comment


def test_insert(comment_model: Comment) -> None:
def test_insert() -> None:
"""Test insert a comment into MongoDB."""
comment_id = comment_model.insert(
comment_id = Comment().insert(
body="<p>This is a test comment</p>",
course_id="course1",
comment_thread_id="66af33634a1e1f001b7ed57f",
author_id="author1",
author_username="author_user",
)
assert comment_id is not None
comment_data = comment_model.get(_id=comment_id)
comment_data = Comment().get(_id=comment_id)
assert comment_data is not None
assert comment_data["body"] == "<p>This is a test comment</p>"


def test_delete(comment_model: Comment) -> None:
def test_delete() -> None:
"""Test delete a comment from MongoDB."""
comment_id = comment_model.insert(
comment_id = Comment().insert(
body="<p>This is a test comment</p>",
course_id="course1",
comment_thread_id="66af33634a1e1f001b7ed57f",
author_id="author1",
author_username="author_user",
)
result = comment_model.delete(comment_id)
result = Comment().delete(comment_id)
assert result == 1
comment_data = comment_model.get(_id=comment_id)
comment_data = Comment().get(_id=comment_id)
assert comment_data is None


def test_list(comment_model: Comment) -> None:
def test_list() -> None:
"""Test list all comments from MongoDB."""
course_id = "course-xyz"
thread_id = "66af33634a1e1f001b7ed57f"
author_id = "4"
author_username = "edly"

comment_model.insert(
Comment().insert(
"<p>Comment 1</p>", course_id, thread_id, author_id, author_username
)
comment_model.insert(
Comment().insert(
"<p>Comment 2</p>", course_id, thread_id, author_id, author_username
)
comment_model.insert(
Comment().insert(
"<p>Comment 3</p>", course_id, thread_id, author_id, author_username
)

comments_list = comment_model.list()
comments_list = Comment().list()
assert len(list(comments_list)) == 3
assert all(comment["body"].startswith("<p>Comment") for comment in comments_list)


def test_update(comment_model: Comment) -> None:
def test_update() -> None:
"""Test update a comment in MongoDB."""
comment_id = comment_model.insert(
comment_id = Comment().insert(
body="<p>This is a test comment</p>",
course_id="course1",
comment_thread_id="66af33634a1e1f001b7ed57f",
author_id="author1",
author_username="author_user",
)

result = comment_model.update(
result = Comment().update(
comment_id=comment_id,
body="<p>Updated comment</p>",
)
assert result == 1
comment_data = comment_model.get(_id=comment_id) or {}
comment_data = Comment().get(_id=comment_id) or {}
assert comment_data.get("body", "") == "<p>Updated comment</p>"
Loading