Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Let the fix tool changes Python files except data.py
Browse files Browse the repository at this point in the history
  • Loading branch information
boholder committed Dec 8, 2022
1 parent 7ba9738 commit 1e85bb9
Show file tree
Hide file tree
Showing 16 changed files with 111 additions and 83 deletions.
54 changes: 28 additions & 26 deletions puntgun/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import datetime
import functools
import itertools
from enum import Enum
from typing import Any, Callable, Iterator, List, TypeVar
from typing import Any, Callable, Iterator, TypeVar

import tweepy
from loguru import logger
Expand Down Expand Up @@ -34,7 +36,7 @@ def __init__(self, title: str, ref_url: str, detail: str, parameter: str, value:
self.value = value

@staticmethod
def from_response(resp_error: dict) -> "TwitterApiError":
def from_response(resp_error: dict) -> TwitterApiError:
# build an accurate error type according to the response content
singleton_iter = (c for c in TwitterApiError.__subclasses__() if c.title == resp_error.get("title"))
# if we haven't written a subclass for given error, return the generic error (in default value way)
Expand Down Expand Up @@ -73,7 +75,7 @@ class TwitterApiErrors(Exception, Recordable):
contains a list of :class:`TwitterApiError`.
"""

def __init__(self, query_func_name: str, query_params: dict, resp_errors: List[dict]):
def __init__(self, query_func_name: str, query_params: dict, resp_errors: list[dict]):
"""Details for debugging via log."""
self.query_func_name = query_func_name
self.query_params = query_params
Expand Down Expand Up @@ -109,7 +111,7 @@ def to_record(self) -> Record:
)

@staticmethod
def parse_from_record(record: Record) -> "TwitterApiErrors":
def parse_from_record(record: Record) -> TwitterApiErrors:
data = record.data
return TwitterApiErrors(data.get("query_func_name", ""), data.get("query_params", ()), data.get("errors", []))

Expand Down Expand Up @@ -270,7 +272,7 @@ def __init__(self, tweepy_client: tweepy.Client):

@staticmethod
@functools.lru_cache(maxsize=1)
def singleton() -> "Client":
def singleton() -> Client:
secrets = secret.load_or_request_all_secrets(encrypto.load_or_generate_private_key())
return Client(
tweepy.Client(
Expand All @@ -281,7 +283,7 @@ def singleton() -> "Client":
)
)

def get_users_by_usernames(self, names: List[str]) -> List[User]:
def get_users_by_usernames(self, names: list[str]) -> list[User]:
"""
Query users information.
**Rate limit: 900 / 15 min**
Expand All @@ -292,7 +294,7 @@ def get_users_by_usernames(self, names: List[str]) -> List[User]:

return response_to_users(self.clt.get_users(usernames=names, **USER_API_PARAMS))

def get_users_by_ids(self, ids: List[int | str]) -> List[User]:
def get_users_by_ids(self, ids: list[int | str]) -> list[User]:
"""
Query users information.
**Rate limit: 900 / 15 min**
Expand All @@ -303,7 +305,7 @@ def get_users_by_ids(self, ids: List[int | str]) -> List[User]:

return response_to_users(self.clt.get_users(ids=ids, **USER_API_PARAMS))

def get_blocked(self) -> List[User]:
def get_blocked(self) -> list[User]:
"""
Get the latest blocking list of the current account.
**Rate limit: 15 / 15 min**
Expand All @@ -312,7 +314,7 @@ def get_blocked(self) -> List[User]:
return query_paged_user_api(self.clt.get_blocked)

@functools.lru_cache(maxsize=1)
def cached_blocked(self) -> List[User]:
def cached_blocked(self) -> list[User]:
"""
Call query method, cache them, and return the cache on latter calls.
Since the tool may be constantly modifying the block list,
Expand All @@ -322,10 +324,10 @@ def cached_blocked(self) -> List[User]:
return self.get_blocked()

@functools.lru_cache(maxsize=1)
def cached_blocked_id_list(self) -> List[int]:
def cached_blocked_id_list(self) -> list[int]:
return [u.id for u in self.cached_blocked()]

def get_following(self, user_id: int | str) -> List[User]:
def get_following(self, user_id: int | str) -> list[User]:
"""
Get the latest following list of a user.
**Rate limit: 15 / 15 min**
Expand All @@ -334,14 +336,14 @@ def get_following(self, user_id: int | str) -> List[User]:
return query_paged_user_api(self.clt.get_users_following, id=user_id)

@functools.lru_cache(maxsize=1)
def cached_following(self) -> List[User]:
def cached_following(self) -> list[User]:
return self.get_following(self.id)

@functools.lru_cache(maxsize=1)
def cached_following_id_list(self) -> List[int]:
def cached_following_id_list(self) -> list[int]:
return [u.id for u in self.cached_following()]

def get_follower(self, user_id: int | str) -> List[User]:
def get_follower(self, user_id: int | str) -> list[User]:
"""
Get the latest follower list of a user.
**Rate limit: 15 / 15 min**
Expand All @@ -350,11 +352,11 @@ def get_follower(self, user_id: int | str) -> List[User]:
return query_paged_user_api(self.clt.get_users_followers, id=user_id)

@functools.lru_cache(maxsize=1)
def cached_follower(self) -> List[User]:
def cached_follower(self) -> list[User]:
return self.get_follower(self.id)

@functools.lru_cache(maxsize=1)
def cached_follower_id_list(self) -> List[int]:
def cached_follower_id_list(self) -> list[int]:
return [u.id for u in self.cached_follower()]

def block_user_by_id(self, target_user_id: int | str) -> bool:
Expand Down Expand Up @@ -383,7 +385,7 @@ def block_user_by_id(self, target_user_id: int | str) -> bool:
# call the block api
return self.clt.block(target_user_id=target_user_id).data["blocking"]

def get_tweets_by_ids(self, ids: List[int | str]) -> List[Tweet]:
def get_tweets_by_ids(self, ids: list[int | str]) -> list[Tweet]:
"""
Query tweets information.
**Rate limit: 900 / 15 min**
Expand All @@ -394,15 +396,15 @@ def get_tweets_by_ids(self, ids: List[int | str]) -> List[Tweet]:

return response_to_tweets(self.clt.get_tweets(ids=ids, **TWEET_API_PARAMS))

def get_users_who_like_tweet(self, tweet_id: int | str) -> List[User]:
def get_users_who_like_tweet(self, tweet_id: int | str) -> list[User]:
"""
Get a Tweet’s liking users (who liked this tweet).
**Rate limit: 75 / 15 min**
https://developer.twitter.com/en/docs/twitter-api/tweets/likes/api-reference/get-tweets-id-liking_users
"""
return query_paged_user_api(self.clt.get_liking_users, max_results=100, id=tweet_id)

def get_users_who_retweet_tweet(self, tweet_id: int | str) -> List[User]:
def get_users_who_retweet_tweet(self, tweet_id: int | str) -> list[User]:
"""
Get users who have retweeted a Tweet.
**Rate limit: 75 / 15 min**
Expand All @@ -419,7 +421,7 @@ def search_tweets(
end_time: datetime.datetime = None,
since_id: int = None,
until_id: int = None,
) -> List[Tweet]:
) -> list[Tweet]:
"""
Search tweets with a query string.
With Essential Twitter API access, the query length is limited up to 512 characters,
Expand All @@ -443,7 +445,7 @@ def search_tweets(
return []


def response_to_users(resp: tweepy.Response) -> List[User]:
def response_to_users(resp: tweepy.Response) -> list[User]:
"""Build a list of :class:`User` instances from one response."""
if not resp.data:
return []
Expand All @@ -459,7 +461,7 @@ def map_one(data: dict) -> User:
return [map_one(d) for d in resp.data]


def response_to_tweets(resp: tweepy.Response) -> List[Tweet]:
def response_to_tweets(resp: tweepy.Response) -> list[Tweet]:
"""Build a list of :class:`Tweet` instances from one response."""
if not resp.data:
return []
Expand Down Expand Up @@ -517,11 +519,11 @@ def map_one(data: dict) -> Tweet:
return [map_one(d) for d in resp.data]


def query_paged_user_api(clt_func: Callable[..., Response], max_results: int = 1000, **kwargs: Any) -> List[User]:
def query_paged_user_api(clt_func: Callable[..., Response], max_results: int = 1000, **kwargs: Any) -> list[User]:
return query_paged_entity_api(clt_func, USER_API_PARAMS, response_to_users, max_results=max_results, **kwargs)


def query_paged_tweet_api(clt_func: Callable[..., Response], times: int = None, **kwargs: Any) -> List[Tweet]:
def query_paged_tweet_api(clt_func: Callable[..., Response], times: int = None, **kwargs: Any) -> list[Tweet]:
return query_paged_entity_api(clt_func, TWEET_API_PARAMS, response_to_tweets, times, max_results=100, **kwargs)


Expand All @@ -531,11 +533,11 @@ def query_paged_tweet_api(clt_func: Callable[..., Response], times: int = None,
def query_paged_entity_api(
clt_func: Callable[..., Response],
api_params: dict,
transforming_func: Callable[[Response], List[E]],
transforming_func: Callable[[Response], list[E]],
times: int = None,
max_results: int = 100,
**kwargs: Any,
) -> List[E]:
) -> list[E]:
# mix two part of params into one dict
params = {}
for k, v in api_params.items():
Expand Down
2 changes: 2 additions & 0 deletions puntgun/commands.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""The implementation of commands"""
from __future__ import annotations

from pathlib import Path

from loguru import logger
Expand Down
6 changes: 4 additions & 2 deletions puntgun/conf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
IMPROVE: Any proper way to unit test this module?
I feel it is too implement-coupling to be valuable enough writing test cases.
"""
from __future__ import annotations

import enum
import os
import sys
Expand Down Expand Up @@ -83,11 +85,11 @@ def to_arg(self) -> str:
return "--" + self.to_arg_str()

@staticmethod
def from_arg_str(arg: str) -> "CommandArg":
def from_arg_str(arg: str) -> CommandArg:
return CommandArg[arg.upper().replace("-", "_")]

@staticmethod
def arg_dict_to_enum_dict(**kwargs: str) -> dict["CommandArg", str]:
def arg_dict_to_enum_dict(**kwargs: str) -> dict[CommandArg, str]:
return {CommandArg.from_arg_str(k): v for k, v in kwargs.items()}


Expand Down
14 changes: 8 additions & 6 deletions puntgun/conf/secret.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import binascii
from pathlib import Path

Expand Down Expand Up @@ -59,21 +61,21 @@ class Config:
max_anystr_length = 100

@staticmethod
def from_environment() -> "TwitterAPISecrets":
def from_environment() -> TwitterAPISecrets:
return TwitterAPISecrets(
key=load_settings_from_environment_variables(twitter_api_key_name),
secret=load_settings_from_environment_variables(twitter_api_key_secret_name),
)

@staticmethod
def from_settings(pri_key: RSAPrivateKey) -> "TwitterAPISecrets":
def from_settings(pri_key: RSAPrivateKey) -> TwitterAPISecrets:
return TwitterAPISecrets(
key=load_and_decrypt_secret_from_settings(pri_key, twitter_api_key_name),
secret=load_and_decrypt_secret_from_settings(pri_key, twitter_api_key_secret_name),
)

@staticmethod
def from_input() -> "TwitterAPISecrets":
def from_input() -> TwitterAPISecrets:
print(GET_API_SECRETS_FROM_INPUT)
return TwitterAPISecrets(
key=util.get_secret_from_terminal("Api key"), secret=util.get_secret_from_terminal("Api key secret")
Expand Down Expand Up @@ -101,21 +103,21 @@ class Config:
max_anystr_length = 100

@staticmethod
def from_environment() -> "TwitterAccessTokenSecrets":
def from_environment() -> TwitterAccessTokenSecrets:
return TwitterAccessTokenSecrets(
token=load_settings_from_environment_variables(twitter_access_token_name),
secret=load_settings_from_environment_variables(twitter_access_token_secret_name),
)

@staticmethod
def from_settings(pri_key: RSAPrivateKey) -> "TwitterAccessTokenSecrets":
def from_settings(pri_key: RSAPrivateKey) -> TwitterAccessTokenSecrets:
return TwitterAccessTokenSecrets(
token=load_and_decrypt_secret_from_settings(pri_key, twitter_access_token_name),
secret=load_and_decrypt_secret_from_settings(pri_key, twitter_access_token_secret_name),
)

@staticmethod
def from_input(api_secrets: TwitterAPISecrets) -> "TwitterAccessTokenSecrets":
def from_input(api_secrets: TwitterAPISecrets) -> TwitterAccessTokenSecrets:
oauth1_user_handler = OAuth1UserHandler(api_secrets.key, api_secrets.secret, callback="oob")
print(AUTH_URL.format(auth_url=oauth1_user_handler.get_authorization_url()))
pin = util.get_input_from_terminal("PIN")
Expand Down
10 changes: 6 additions & 4 deletions puntgun/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
IMPROVE: More elegant way to generating a json format report file.
"""
from __future__ import annotations

import datetime
from typing import Any, List
from typing import Any

import orjson
from loguru import logger
Expand All @@ -41,7 +43,7 @@ def to_json(self) -> bytes:
return orjson.dumps({"type": self.type, "data": self.data})

@staticmethod
def parse_from_dict(conf: dict) -> "Record":
def parse_from_dict(conf: dict) -> Record:
"""
Assume that the parameter is already a dictionary type parsed from a json file.
"""
Expand All @@ -67,7 +69,7 @@ def to_record(self) -> Record:
raise NotImplementedError

@staticmethod
def parse_from_record(record: Record) -> "Recordable":
def parse_from_record(record: Record) -> Recordable:
"""Generate an instance from a record."""
raise NotImplementedError

Expand Down Expand Up @@ -102,7 +104,7 @@ def record(recordable: Recordable) -> None:
Recorder._write(recordable.to_record().to_json() + COMMA)

@staticmethod
def write_report_header(plans: List[Plan]) -> None:
def write_report_header(plans: list[Plan]) -> None:
"""
This paragraph works as the report file content's header
- for correctly formatting latter records in json format.
Expand Down
10 changes: 6 additions & 4 deletions puntgun/rules/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import datetime
import itertools
import sys
from typing import ClassVar, List
from typing import ClassVar

from pydantic import BaseModel, Field, root_validator
from reactivex import Observable
Expand All @@ -17,7 +19,7 @@ class FromConfig(BaseModel):
_keyword: ClassVar[str] = "corresponding_rule_name_in_config_of_this_rule"

@classmethod
def parse_from_config(cls, conf: dict) -> "FromConfig":
def parse_from_config(cls, conf: dict) -> FromConfig:
"""
Most rules have a dictionary structure of fields, their configurations are something like:
{ 'rule_name': {'field_1':1, 'field_2':2,...} }
Expand Down Expand Up @@ -76,7 +78,7 @@ def __call__(self) -> Observable:
raise NotImplementedError


def validate_required_fields_exist(rule_keyword: str, conf: dict, required_field_names: List[str]) -> None:
def validate_required_fields_exist(rule_keyword: str, conf: dict, required_field_names: list[str]) -> None:
"""
Custom configuration parsing process
- :class:`ConfigParser` sort of bypass the pydantic library's validation,
Expand All @@ -92,7 +94,7 @@ def validate_required_fields_exist(rule_keyword: str, conf: dict, required_field
raise ValueError(f"Missing required field(s) {missing} in configuration [{rule_keyword}]: {conf}")


def validate_fields_conflict(values: dict, field_groups: List[List[str]]) -> None:
def validate_fields_conflict(values: dict, field_groups: list[list[str]]) -> None:
"""
:param values: configuration dictionary
:param field_groups: no conflict inside each group
Expand Down

0 comments on commit 1e85bb9

Please sign in to comment.