Skip to content

Commit

Permalink
Refactor the code (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikelane committed May 16, 2022
1 parent 18ea7d2 commit 48456e5
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 79 deletions.
4 changes: 4 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[run]
omit =
reddit_get/types/__init__.py

15 changes: 15 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
coverage:
status:
project:
default:
target: 90%
threshold: 5%
patch:
target: 90%
threshold: 5%

ignore:
- "reddit_get/__init__.py"
- "reddit_get/types/__init__.py"
- "tests/"

16 changes: 14 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ fire = '>=0.3.1,<0.5.0'
praw = "^7.6.0"
toml = '^0.10.2'
titlecase = "^2.3.0"
typing-extensions = "^4.2.0"

[tool.poetry.dev-dependencies]
black = { version = '*', allow-prereleases = true }
Expand All @@ -42,6 +43,7 @@ isort = "*"
pytest-isort = "*"
pydantic = "^1.9.0"
pytest-mypy = {version = "*", allow-prereleases = true}
types-toml = "^0.10.7"

[tool.pytest.ini_options]
minversion = '6.0'
Expand Down Expand Up @@ -73,6 +75,7 @@ exclude = '''
[tool.isort]
profile = "black"
multi_line_output = 3
force_grid_wrap = 2

[build-system]
requires = ['poetry-core>=1.0.0']
Expand Down
118 changes: 43 additions & 75 deletions reddit_get/cli.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
import functools
import sys
from pathlib import Path
from string import Formatter
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Union
from typing import (
Dict,
List,
Union,
)

import fire
import praw
import toml
from praw.exceptions import MissingRequiredAttributeException
from praw.models import Submission
from praw.models.reddit.subreddit import Subreddit

from reddit_get.types import SortingOption, TimeFilterOption
from .types import (
SortingOption,
TimeFilterOption,
)
from .utils import (
create_post_output,
get_post_sorting_option,
get_reddit_query_function,
get_response,
get_template_keys,
get_time_filter_option,
load_configs,
)


class RedditCli:
Expand All @@ -38,19 +48,9 @@ class RedditCli:
"""

def __init__(self, config: str = '~/.redditgetrc'):
self.config_path: Path = Path(config).expanduser()
try:
self.configs = toml.load(self.config_path)
except (FileNotFoundError, toml.TomlDecodeError):
raise fire.core.FireError(f'No valid TOML config found at {self.config_path}')
try:
self.reddit = praw.Reddit(**self.configs['reddit-get'])
except MissingRequiredAttributeException as e: # pragma: no cover
fire.core.FireError(e)
if not self.reddit.user.me():
raise fire.core.FireError( # pragma: no cover
'Failed to authenticate with Reddit. Did you remember your username and password?'
)
self.config_path, self.configs = load_configs(config)
self.reddit = self.get_authenticated_reddit_instance()

self.valid_header_variables: Dict[str, Dict[Union[SortingOption, TimeFilterOption], str]] = {
'sorting': {
SortingOption.CONTROVERSIAL: 'Most Controversial',
Expand All @@ -71,6 +71,17 @@ def __init__(self, config: str = '~/.redditgetrc'):
},
}

def get_authenticated_reddit_instance(self):
try:
reddit = praw.Reddit(**self.configs['reddit-get'])
if not reddit.user.me():
raise fire.core.FireError( # pragma: no cover
'Failed to authenticate with Reddit. Did you remember your username and password?'
)
return reddit
except MissingRequiredAttributeException as e: # pragma: no cover
fire.core.FireError(e)

def config_location(self):
"""Get the path of the reddit-get config.
Expand All @@ -81,11 +92,11 @@ def config_location(self):
else:
raise fire.core.FireError(f'No config_path has been set!')

def _create_header(
def create_header(
self, template: str, sorting: SortingOption, time: TimeFilterOption, subreddit: str
) -> str:
valid_keys = {'sorting', 'time', 'subreddit'}
keys = self._get_template_keys(template)
keys = get_template_keys(template)
if keys and not keys.issubset(valid_keys):
raise fire.core.FireError(
f'Invalid keys passed into header template: {", ".join(keys - valid_keys)}'
Expand All @@ -97,24 +108,6 @@ def _create_header(
}
return template.format(**format_params)

def _create_post_output(self, template: str, posts: Iterator[Submission]) -> List[str]:
template_vars = self._get_template_keys(template)
if not template_vars:
raise fire.core.FireError('Your post output template did not have any items to be printed')
results = []
for post in posts:
try:
format_params = {key: getattr(post, key) for key in template_vars}
results.append(template.format(**format_params))
except AttributeError as e:
raise fire.core.FireError(e)
return results

@staticmethod
def _get_template_keys(template: str) -> Optional[Set[str]]:
template_vars = {tup[1] for tup in Formatter().parse(template) if tup[1] and isinstance(tup[1], str)}
return template_vars or None

def post(
self,
subreddit: str,
Expand Down Expand Up @@ -178,45 +171,20 @@ def post(
The number of post titles from the specified subreddit
formatted as specified
"""
try:
post_sorting = SortingOption(post_sorting)
except ValueError:
raise fire.core.FireError(f'{post_sorting} is not a valid sorting option.')
try:
time_filter = TimeFilterOption(time_filter)
except ValueError:
raise fire.core.FireError(f'{time_filter} is not a valid time filter option')
if not 0 < limit <= 25:
raise fire.core.FireError('You may only get between 1 and 25 submissions')

praw_subreddit: Subreddit = self.reddit.subreddit(subreddit)

call_map: Dict[SortingOption, Callable[[Optional[int]], Iterator[Any]]] = {
SortingOption.CONTROVERSIAL: functools.partial(
praw_subreddit.controversial, time_filter=time_filter
sorting = get_post_sorting_option(post_sorting)
query_fn = get_reddit_query_function(self.reddit.subreddit(subreddit), time_filter, sorting)
return get_response(
self.create_header(
template=custom_header,
sorting=sorting,
time=get_time_filter_option(time_filter),
subreddit=subreddit,
),
SortingOption.GILDED: praw_subreddit.gilded,
SortingOption.HOT: praw_subreddit.hot,
SortingOption.NEW: praw_subreddit.new,
SortingOption.RANDOM_RISING: praw_subreddit.random_rising,
SortingOption.RISING: praw_subreddit.rising,
SortingOption.TOP: functools.partial(praw_subreddit.top, time_filter=time_filter),
}

response_header = (
[
self._create_header(
template=custom_header, sorting=post_sorting, time=time_filter, subreddit=subreddit
)
]
if header
else []
create_post_output(output_format, query_fn(limit=limit)),
)

posts: List[str] = self._create_post_output(output_format, call_map[post_sorting](limit=limit)) # type: ignore

return response_header + posts


def main(): # pragma: no cover
try:
Expand Down
22 changes: 22 additions & 0 deletions reddit_get/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,23 @@
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
)

try:
from typing import Protocol
except ImportError: # pragma: no cover
from typing_extensions import Protocol # type: ignore

from .enums import *


class PrawQuery(Protocol): # pragma: no cover
def __call__(self, limit: Optional[int]) -> Iterator[Any]:
...


CallMap = Dict[SortingOption, PrawQuery]
Posts = List[str]
5 changes: 4 additions & 1 deletion reddit_get/types/enums.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from enum import Enum, EnumMeta
from enum import (
Enum,
EnumMeta,
)


class MetaEnum(EnumMeta):
Expand Down
89 changes: 89 additions & 0 deletions reddit_get/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import functools
from pathlib import Path
from string import Formatter
from typing import (
Iterator,
List,
Optional,
Set,
)

import fire
import toml
from praw.models import (
Submission,
Subreddit,
)

from .types import (
CallMap,
PrawQuery,
SortingOption,
TimeFilterOption,
)


def load_configs(config):
config_path: Path = Path(config).expanduser()
try:
configs = toml.load(config_path)
except (FileNotFoundError, toml.TomlDecodeError):
raise fire.core.FireError(f'No valid TOML config found at {config_path}')
return config_path, configs


def get_reddit_query_function(
subreddit: Subreddit, time_filter: str = 'all', post_sorting: SortingOption = SortingOption.TOP
) -> PrawQuery:
call_map: CallMap = {
SortingOption.CONTROVERSIAL: functools.partial(subreddit.controversial, time_filter=time_filter),
SortingOption.GILDED: subreddit.gilded,
SortingOption.HOT: subreddit.hot,
SortingOption.NEW: subreddit.new,
SortingOption.RANDOM_RISING: subreddit.random_rising,
SortingOption.RISING: subreddit.rising,
SortingOption.TOP: functools.partial(subreddit.top, time_filter=time_filter),
}
try:
return call_map[post_sorting]
except KeyError:
raise fire.core.FireError(f'Invalid sorting option: {post_sorting}')


def get_response(header: str, posts: List[str]) -> List[str]:
response_header = [header] if header else []
return response_header + posts


def get_time_filter_option(time_filter):
try:
time_filter = TimeFilterOption(time_filter)
except ValueError:
raise fire.core.FireError(f'{time_filter} is not a valid time filter option')
return time_filter


def get_post_sorting_option(post_sorting: str) -> SortingOption:
try:
return SortingOption(post_sorting)
except ValueError:
raise fire.core.FireError(f'{post_sorting} is not a valid sorting option.')


def get_template_keys(template: str) -> Optional[Set[str]]:
template_vars = {tup[1] for tup in Formatter().parse(template) if tup[1] and isinstance(tup[1], str)}
return template_vars or None


def create_post_output(template: str, posts: Iterator[Submission]) -> List[str]:
template_vars = get_template_keys(template)
if not template_vars:
raise fire.core.FireError('Your post output template did not have any items to be printed')
results = []
for post in posts:
try:
format_params = {key: getattr(post, key) for key in template_vars}
results.append(template.format(**format_params))
except AttributeError as e:
raise fire.core.FireError(e)
return results
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class MockSubreddit:
def __init__(self, display_name: str, *args, **kwargs):
self.display_name = display_name

def __repr__(self):
def __repr__(self): # pragma: nocover
return self.display_name

def controversial(self, *args, **kwargs):
Expand Down
Loading

0 comments on commit 48456e5

Please sign in to comment.