diff --git a/fineladybot/database.py b/fineladybot/database.py index 867f274..1b4ad97 100644 --- a/fineladybot/database.py +++ b/fineladybot/database.py @@ -9,15 +9,20 @@ class Database: - def __init__(self, name: str, logger: Logger) -> None: + def __init__(self, name: str, logger: Logger, testing=False) -> None: self.name = name self.logger = logger + self.testing = testing + self.connection = self.db_connection() self.initialise_db() def db_connection(self) -> sqlite3.Connection: con: sqlite3.Connection try: - con = sqlite3.connect(str(filepath.joinpath(self.name)) + ".db") + if self.testing: + con = sqlite3.connect(":memory:") + else: + con = sqlite3.connect(str(filepath.joinpath(self.name)) + ".db") # self.logger.info("Connected to finelady.db\nSQLite3 version %s", sqlite3.version) except Error as e: self.logger.info(e) @@ -44,7 +49,7 @@ def initialise_db(self) -> None: self._create_table(query) def _create_table(self, sql: str, args: Optional[str] = None) -> None: - with self.db_connection() as con: + with self.connection as con: cur = con.cursor() if args: cur.execute(sql, args) @@ -54,20 +59,20 @@ def _create_table(self, sql: str, args: Optional[str] = None) -> None: def add_opt_out_user(self, name: str, request_date: datetime) -> None: sql = """INSERT OR IGNORE INTO opt_out_users(username, request_date) VALUES(?,?);""" - with self.db_connection() as con: + with self.connection as con: cur = con.cursor() cur.execute(sql, (name, request_date)) def add_opt_out_sub(self, subreddit: str, requestor: str, request_date: datetime) -> None: sql = """INSERT OR IGNORE INTO opt_out_subs(subreddit, requestor, request_date) VALUES(?,?,?);""" - with self.db_connection() as con: + with self.connection as con: cur = con.cursor() cur.execute(sql, (subreddit, requestor, request_date)) def query_users(self) -> list[str]: sql = """SELECT username FROM opt_out_users""" - with self.db_connection() as con: + with self.connection as con: cur = con.cursor() cur.execute(sql) users = cur.fetchall() @@ -75,7 +80,7 @@ def query_users(self) -> list[str]: def query_subs(self) -> list[str]: sql = """SELECT subreddit FROM opt_out_subs""" - with self.db_connection() as con: + with self.connection as con: cur = con.cursor() cur.execute(sql) subs = cur.fetchall() diff --git a/fineladybot/finelady.py b/fineladybot/finelady.py index e6e4d41..dfc350a 100755 --- a/fineladybot/finelady.py +++ b/fineladybot/finelady.py @@ -70,8 +70,10 @@ def run() -> None: if "user_opt_out" in message.subject: message_date = datetime.fromtimestamp(message.created_utc) db.add_opt_out_user(message.author.name, message_date) + opt_out_list.append(message.author.name) if "sub_opt_out" in message.subject: - parse_sub_opt_out(message, reddit) + subreddit = parse_sub_opt_out(message, reddit) + sub_opt_out_list.append(subreddit.display_name) def get_opt_out_url() -> str: @@ -125,7 +127,7 @@ def crosspost_submission(submission: Submission) -> None: return crosspost -def parse_sub_opt_out(message: Message, reddit: praw.Reddit) -> None: +def parse_sub_opt_out(message: Message, reddit: praw.Reddit, db: Database = db) -> None: """Parse a request to opt out from being cross-posted by a sub moderator. Add the request to the database.""" message_date = datetime.fromtimestamp(message.created_utc) @@ -140,3 +142,4 @@ def parse_sub_opt_out(message: Message, reddit: praw.Reddit) -> None: from_mod = message.author.name in subreddit_moderators if from_mod: db.add_opt_out_sub(subreddit.display_name, message.author.name, message_date) + return subreddit diff --git a/tests/conftest.py b/tests/conftest.py index b3d722c..2963022 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,27 @@ import pytest +from logging import Logger +from unittest.mock import Mock, MagicMock + import praw.models # type: ignore +from fineladybot.database import Database + +mock_logger = Mock(spec=Logger) + + +@pytest.fixture +def reddit() -> praw.Reddit: + reddit = MagicMock(spec=praw.Reddit) + reddit.subreddit = praw.models.SubredditHelper(reddit, None) + return reddit + @pytest.fixture -def subreddit() -> praw.models.Subreddit: - return praw.models.Subreddit(None, "testsub") +def subreddit(reddit) -> praw.models.Subreddit: + return praw.models.Subreddit(reddit, "testsub") + + +@pytest.fixture(scope="session") +def db(): + database = Database("test_db", mock_logger, True) + return database diff --git a/tests/test_opt_out.py b/tests/test_opt_out.py new file mode 100644 index 0000000..78c9957 --- /dev/null +++ b/tests/test_opt_out.py @@ -0,0 +1,61 @@ +from datetime import datetime +from typing import Any +from unittest.mock import MagicMock +from urllib.parse import quote_plus + +import pytest +from praw.models import Message, Subreddit, User +from praw import Reddit + +from fineladybot.database import Database +from fineladybot.finelady import parse_sub_opt_out + + +@pytest.mark.parametrize( + "user,date", + [ + ("tim", datetime(2022, 1, 1)), + ], +) +def test_user_opt_out(user, date, db: Database): + db.add_opt_out_user(user, date) + opted_out_users = db.query_users() + assert opted_out_users[0] == user + + +def mock_moderator_method_factory(list_of_mods: list[str]) -> callable: + def mock_moderator_method(self) -> list[str]: + return list_of_mods + + return mock_moderator_method + + +@pytest.mark.parametrize( + "author, mods, expected", + [ + ("anon", ["mod1"], False), # author is not mod + ("mod1", ["mod1"], True), # author is mod + ], +) +def test_sub_opt_out( + author: str, + mods: list[str], + expected: list[str], + monkeypatch: Any, + reddit: Reddit, + subreddit: Subreddit, + db: Database, +): + mock_moderator_method = mock_moderator_method_factory(mods) + monkeypatch.setattr(Subreddit, "moderator", mock_moderator_method) + msg = MagicMock(spec=Message) + user = MagicMock(spec=User) + msg.created_utc = datetime(2022, 1, 1).timestamp() + msg.body = f"Please do not post anything further to /r/{subreddit}" + msg.author = user + msg.author.name = author + opted_out_subs = [] + parse_sub_opt_out(msg, reddit, db) + opted_out_subs = db.query_subs() + result = opted_out_subs == [subreddit.display_name] + assert result == expected