diff --git a/puntgun/client.py b/puntgun/client.py index 2589921..005c17d 100644 --- a/puntgun/client.py +++ b/puntgun/client.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import datetime import functools import itertools from enum import Enum -from typing import Any, Callable, Iterator, List, TypeVar +from typing import Any, Callable, Iterator, TypeVar import tweepy from loguru import logger @@ -34,7 +36,7 @@ def __init__(self, title: str, ref_url: str, detail: str, parameter: str, value: self.value = value @staticmethod - def from_response(resp_error: dict) -> "TwitterApiError": + def from_response(resp_error: dict) -> TwitterApiError: # build an accurate error type according to the response content singleton_iter = (c for c in TwitterApiError.__subclasses__() if c.title == resp_error.get("title")) # if we haven't written a subclass for given error, return the generic error (in default value way) @@ -73,7 +75,7 @@ class TwitterApiErrors(Exception, Recordable): contains a list of :class:`TwitterApiError`. """ - def __init__(self, query_func_name: str, query_params: dict, resp_errors: List[dict]): + def __init__(self, query_func_name: str, query_params: dict, resp_errors: list[dict]): """Details for debugging via log.""" self.query_func_name = query_func_name self.query_params = query_params @@ -109,7 +111,7 @@ def to_record(self) -> Record: ) @staticmethod - def parse_from_record(record: Record) -> "TwitterApiErrors": + def parse_from_record(record: Record) -> TwitterApiErrors: data = record.data return TwitterApiErrors(data.get("query_func_name", ""), data.get("query_params", ()), data.get("errors", [])) @@ -270,7 +272,7 @@ def __init__(self, tweepy_client: tweepy.Client): @staticmethod @functools.lru_cache(maxsize=1) - def singleton() -> "Client": + def singleton() -> Client: secrets = secret.load_or_request_all_secrets(encrypto.load_or_generate_private_key()) return Client( tweepy.Client( @@ -281,7 +283,7 @@ def singleton() -> "Client": ) ) - def get_users_by_usernames(self, names: List[str]) -> List[User]: + def get_users_by_usernames(self, names: list[str]) -> list[User]: """ Query users information. **Rate limit: 900 / 15 min** @@ -292,7 +294,7 @@ def get_users_by_usernames(self, names: List[str]) -> List[User]: return response_to_users(self.clt.get_users(usernames=names, **USER_API_PARAMS)) - def get_users_by_ids(self, ids: List[int | str]) -> List[User]: + def get_users_by_ids(self, ids: list[int | str]) -> list[User]: """ Query users information. **Rate limit: 900 / 15 min** @@ -303,7 +305,7 @@ def get_users_by_ids(self, ids: List[int | str]) -> List[User]: return response_to_users(self.clt.get_users(ids=ids, **USER_API_PARAMS)) - def get_blocked(self) -> List[User]: + def get_blocked(self) -> list[User]: """ Get the latest blocking list of the current account. **Rate limit: 15 / 15 min** @@ -312,7 +314,7 @@ def get_blocked(self) -> List[User]: return query_paged_user_api(self.clt.get_blocked) @functools.lru_cache(maxsize=1) - def cached_blocked(self) -> List[User]: + def cached_blocked(self) -> list[User]: """ Call query method, cache them, and return the cache on latter calls. Since the tool may be constantly modifying the block list, @@ -322,10 +324,10 @@ def cached_blocked(self) -> List[User]: return self.get_blocked() @functools.lru_cache(maxsize=1) - def cached_blocked_id_list(self) -> List[int]: + def cached_blocked_id_list(self) -> list[int]: return [u.id for u in self.cached_blocked()] - def get_following(self, user_id: int | str) -> List[User]: + def get_following(self, user_id: int | str) -> list[User]: """ Get the latest following list of a user. **Rate limit: 15 / 15 min** @@ -334,14 +336,14 @@ def get_following(self, user_id: int | str) -> List[User]: return query_paged_user_api(self.clt.get_users_following, id=user_id) @functools.lru_cache(maxsize=1) - def cached_following(self) -> List[User]: + def cached_following(self) -> list[User]: return self.get_following(self.id) @functools.lru_cache(maxsize=1) - def cached_following_id_list(self) -> List[int]: + def cached_following_id_list(self) -> list[int]: return [u.id for u in self.cached_following()] - def get_follower(self, user_id: int | str) -> List[User]: + def get_follower(self, user_id: int | str) -> list[User]: """ Get the latest follower list of a user. **Rate limit: 15 / 15 min** @@ -350,11 +352,11 @@ def get_follower(self, user_id: int | str) -> List[User]: return query_paged_user_api(self.clt.get_users_followers, id=user_id) @functools.lru_cache(maxsize=1) - def cached_follower(self) -> List[User]: + def cached_follower(self) -> list[User]: return self.get_follower(self.id) @functools.lru_cache(maxsize=1) - def cached_follower_id_list(self) -> List[int]: + def cached_follower_id_list(self) -> list[int]: return [u.id for u in self.cached_follower()] def block_user_by_id(self, target_user_id: int | str) -> bool: @@ -383,7 +385,7 @@ def block_user_by_id(self, target_user_id: int | str) -> bool: # call the block api return self.clt.block(target_user_id=target_user_id).data["blocking"] - def get_tweets_by_ids(self, ids: List[int | str]) -> List[Tweet]: + def get_tweets_by_ids(self, ids: list[int | str]) -> list[Tweet]: """ Query tweets information. **Rate limit: 900 / 15 min** @@ -394,7 +396,7 @@ def get_tweets_by_ids(self, ids: List[int | str]) -> List[Tweet]: return response_to_tweets(self.clt.get_tweets(ids=ids, **TWEET_API_PARAMS)) - def get_users_who_like_tweet(self, tweet_id: int | str) -> List[User]: + def get_users_who_like_tweet(self, tweet_id: int | str) -> list[User]: """ Get a Tweet’s liking users (who liked this tweet). **Rate limit: 75 / 15 min** @@ -402,7 +404,7 @@ def get_users_who_like_tweet(self, tweet_id: int | str) -> List[User]: """ return query_paged_user_api(self.clt.get_liking_users, max_results=100, id=tweet_id) - def get_users_who_retweet_tweet(self, tweet_id: int | str) -> List[User]: + def get_users_who_retweet_tweet(self, tweet_id: int | str) -> list[User]: """ Get users who have retweeted a Tweet. **Rate limit: 75 / 15 min** @@ -419,7 +421,7 @@ def search_tweets( end_time: datetime.datetime = None, since_id: int = None, until_id: int = None, - ) -> List[Tweet]: + ) -> list[Tweet]: """ Search tweets with a query string. With Essential Twitter API access, the query length is limited up to 512 characters, @@ -443,7 +445,7 @@ def search_tweets( return [] -def response_to_users(resp: tweepy.Response) -> List[User]: +def response_to_users(resp: tweepy.Response) -> list[User]: """Build a list of :class:`User` instances from one response.""" if not resp.data: return [] @@ -459,7 +461,7 @@ def map_one(data: dict) -> User: return [map_one(d) for d in resp.data] -def response_to_tweets(resp: tweepy.Response) -> List[Tweet]: +def response_to_tweets(resp: tweepy.Response) -> list[Tweet]: """Build a list of :class:`Tweet` instances from one response.""" if not resp.data: return [] @@ -517,11 +519,11 @@ def map_one(data: dict) -> Tweet: return [map_one(d) for d in resp.data] -def query_paged_user_api(clt_func: Callable[..., Response], max_results: int = 1000, **kwargs: Any) -> List[User]: +def query_paged_user_api(clt_func: Callable[..., Response], max_results: int = 1000, **kwargs: Any) -> list[User]: return query_paged_entity_api(clt_func, USER_API_PARAMS, response_to_users, max_results=max_results, **kwargs) -def query_paged_tweet_api(clt_func: Callable[..., Response], times: int = None, **kwargs: Any) -> List[Tweet]: +def query_paged_tweet_api(clt_func: Callable[..., Response], times: int = None, **kwargs: Any) -> list[Tweet]: return query_paged_entity_api(clt_func, TWEET_API_PARAMS, response_to_tweets, times, max_results=100, **kwargs) @@ -531,11 +533,11 @@ def query_paged_tweet_api(clt_func: Callable[..., Response], times: int = None, def query_paged_entity_api( clt_func: Callable[..., Response], api_params: dict, - transforming_func: Callable[[Response], List[E]], + transforming_func: Callable[[Response], list[E]], times: int = None, max_results: int = 100, **kwargs: Any, -) -> List[E]: +) -> list[E]: # mix two part of params into one dict params = {} for k, v in api_params.items(): diff --git a/puntgun/commands.py b/puntgun/commands.py index a71b597..61f6548 100644 --- a/puntgun/commands.py +++ b/puntgun/commands.py @@ -1,4 +1,6 @@ """The implementation of commands""" +from __future__ import annotations + from pathlib import Path from loguru import logger diff --git a/puntgun/conf/config.py b/puntgun/conf/config.py index 91250ff..c08faf4 100644 --- a/puntgun/conf/config.py +++ b/puntgun/conf/config.py @@ -4,6 +4,8 @@ IMPROVE: Any proper way to unit test this module? I feel it is too implement-coupling to be valuable enough writing test cases. """ +from __future__ import annotations + import enum import os import sys @@ -83,11 +85,11 @@ def to_arg(self) -> str: return "--" + self.to_arg_str() @staticmethod - def from_arg_str(arg: str) -> "CommandArg": + def from_arg_str(arg: str) -> CommandArg: return CommandArg[arg.upper().replace("-", "_")] @staticmethod - def arg_dict_to_enum_dict(**kwargs: str) -> dict["CommandArg", str]: + def arg_dict_to_enum_dict(**kwargs: str) -> dict[CommandArg, str]: return {CommandArg.from_arg_str(k): v for k, v in kwargs.items()} diff --git a/puntgun/conf/secret.py b/puntgun/conf/secret.py index 00b3f43..c0d2ece 100644 --- a/puntgun/conf/secret.py +++ b/puntgun/conf/secret.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import binascii from pathlib import Path @@ -59,21 +61,21 @@ class Config: max_anystr_length = 100 @staticmethod - def from_environment() -> "TwitterAPISecrets": + def from_environment() -> TwitterAPISecrets: return TwitterAPISecrets( key=load_settings_from_environment_variables(twitter_api_key_name), secret=load_settings_from_environment_variables(twitter_api_key_secret_name), ) @staticmethod - def from_settings(pri_key: RSAPrivateKey) -> "TwitterAPISecrets": + def from_settings(pri_key: RSAPrivateKey) -> TwitterAPISecrets: return TwitterAPISecrets( key=load_and_decrypt_secret_from_settings(pri_key, twitter_api_key_name), secret=load_and_decrypt_secret_from_settings(pri_key, twitter_api_key_secret_name), ) @staticmethod - def from_input() -> "TwitterAPISecrets": + def from_input() -> TwitterAPISecrets: print(GET_API_SECRETS_FROM_INPUT) return TwitterAPISecrets( key=util.get_secret_from_terminal("Api key"), secret=util.get_secret_from_terminal("Api key secret") @@ -101,21 +103,21 @@ class Config: max_anystr_length = 100 @staticmethod - def from_environment() -> "TwitterAccessTokenSecrets": + def from_environment() -> TwitterAccessTokenSecrets: return TwitterAccessTokenSecrets( token=load_settings_from_environment_variables(twitter_access_token_name), secret=load_settings_from_environment_variables(twitter_access_token_secret_name), ) @staticmethod - def from_settings(pri_key: RSAPrivateKey) -> "TwitterAccessTokenSecrets": + def from_settings(pri_key: RSAPrivateKey) -> TwitterAccessTokenSecrets: return TwitterAccessTokenSecrets( token=load_and_decrypt_secret_from_settings(pri_key, twitter_access_token_name), secret=load_and_decrypt_secret_from_settings(pri_key, twitter_access_token_secret_name), ) @staticmethod - def from_input(api_secrets: TwitterAPISecrets) -> "TwitterAccessTokenSecrets": + def from_input(api_secrets: TwitterAPISecrets) -> TwitterAccessTokenSecrets: oauth1_user_handler = OAuth1UserHandler(api_secrets.key, api_secrets.secret, callback="oob") print(AUTH_URL.format(auth_url=oauth1_user_handler.get_authorization_url())) pin = util.get_input_from_terminal("PIN") diff --git a/puntgun/record.py b/puntgun/record.py index 878ff1f..97dfcce 100644 --- a/puntgun/record.py +++ b/puntgun/record.py @@ -17,8 +17,10 @@ IMPROVE: More elegant way to generating a json format report file. """ +from __future__ import annotations + import datetime -from typing import Any, List +from typing import Any import orjson from loguru import logger @@ -41,7 +43,7 @@ def to_json(self) -> bytes: return orjson.dumps({"type": self.type, "data": self.data}) @staticmethod - def parse_from_dict(conf: dict) -> "Record": + def parse_from_dict(conf: dict) -> Record: """ Assume that the parameter is already a dictionary type parsed from a json file. """ @@ -67,7 +69,7 @@ def to_record(self) -> Record: raise NotImplementedError @staticmethod - def parse_from_record(record: Record) -> "Recordable": + def parse_from_record(record: Record) -> Recordable: """Generate an instance from a record.""" raise NotImplementedError @@ -102,7 +104,7 @@ def record(recordable: Recordable) -> None: Recorder._write(recordable.to_record().to_json() + COMMA) @staticmethod - def write_report_header(plans: List[Plan]) -> None: + def write_report_header(plans: list[Plan]) -> None: """ This paragraph works as the report file content's header - for correctly formatting latter records in json format. diff --git a/puntgun/rules/base.py b/puntgun/rules/base.py index a98c20c..03bd2ff 100644 --- a/puntgun/rules/base.py +++ b/puntgun/rules/base.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import datetime import itertools import sys -from typing import ClassVar, List +from typing import ClassVar from pydantic import BaseModel, Field, root_validator from reactivex import Observable @@ -17,7 +19,7 @@ class FromConfig(BaseModel): _keyword: ClassVar[str] = "corresponding_rule_name_in_config_of_this_rule" @classmethod - def parse_from_config(cls, conf: dict) -> "FromConfig": + def parse_from_config(cls, conf: dict) -> FromConfig: """ Most rules have a dictionary structure of fields, their configurations are something like: { 'rule_name': {'field_1':1, 'field_2':2,...} } @@ -76,7 +78,7 @@ def __call__(self) -> Observable: raise NotImplementedError -def validate_required_fields_exist(rule_keyword: str, conf: dict, required_field_names: List[str]) -> None: +def validate_required_fields_exist(rule_keyword: str, conf: dict, required_field_names: list[str]) -> None: """ Custom configuration parsing process - :class:`ConfigParser` sort of bypass the pydantic library's validation, @@ -92,7 +94,7 @@ def validate_required_fields_exist(rule_keyword: str, conf: dict, required_field raise ValueError(f"Missing required field(s) {missing} in configuration [{rule_keyword}]: {conf}") -def validate_fields_conflict(values: dict, field_groups: List[List[str]]) -> None: +def validate_fields_conflict(values: dict, field_groups: list[list[str]]) -> None: """ :param values: configuration dictionary :param field_groups: no conflict inside each group diff --git a/puntgun/rules/config_parser.py b/puntgun/rules/config_parser.py index adbaaf0..0f7c9ca 100644 --- a/puntgun/rules/config_parser.py +++ b/puntgun/rules/config_parser.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import importlib import pkgutil -from typing import Any, List, Type, TypeVar +from typing import Any, TypeVar from loguru import logger from pydantic import ValidationError @@ -35,12 +37,12 @@ class ConfigParser: # so I guess it's ok to use a class variable to store the errors, # and use this class as singleton pattern. # Sort of inconvenient when unit testing. - _errors: List[Exception] = [] + _errors: list[Exception] = [] _T = TypeVar("_T", bound=FromConfig) @staticmethod - def parse(conf: dict, expected_type: _T | Type[_T]) -> _T: + def parse(conf: dict, expected_type: _T | type[_T]) -> _T: """ Take a piece of configuration and the expected type from caller, recognize which rule it is and parse it into corresponding rule instance. @@ -78,7 +80,7 @@ def generate_placeholder_instance() -> Any: return generate_placeholder_instance() @staticmethod - def errors() -> List[Exception]: + def errors() -> list[Exception]: """Get errors occurred when paring plan configuration""" return ConfigParser._errors diff --git a/puntgun/rules/data.py b/puntgun/rules/data.py index f411097..aa5394d 100644 --- a/puntgun/rules/data.py +++ b/puntgun/rules/data.py @@ -92,7 +92,7 @@ def from_response(resp_data: Mapping, pinned_tweet: "Tweet" = None) -> "User": following_count=public_metrics.get("following_count", 0), tweet_count=public_metrics.get("tweet_count", 0), pinned_tweet_text=pinned_tweet.text, - pinned_tweet=pinned_tweet + pinned_tweet=pinned_tweet, ) class Config: @@ -329,7 +329,7 @@ def corresponding_tweet(_id: str) -> "Tweet": mediums=mediums, polls=polls, place=place, - related_tweets=relations + related_tweets=relations, ) class Config: diff --git a/puntgun/rules/user/filter_rules.py b/puntgun/rules/user/filter_rules.py index 6ea2601..a23f610 100644 --- a/puntgun/rules/user/filter_rules.py +++ b/puntgun/rules/user/filter_rules.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import re from typing import ClassVar @@ -87,7 +89,7 @@ class CreatedAfterUserFilterRule(UserFilterRule): _keyword: ClassVar[str] = "created_after" @classmethod - def parse_from_config(cls, conf: dict) -> "CreatedUserFilterRule": + def parse_from_config(cls, conf: dict) -> CreatedUserFilterRule: return CreatedUserFilterRule(after=conf[cls._keyword]) @@ -96,7 +98,7 @@ class CreatedWithinDaysUserFilterRule(UserFilterRule): within_days: int @classmethod - def parse_from_config(cls, conf: dict) -> "CreatedWithinDaysUserFilterRule": + def parse_from_config(cls, conf: dict) -> CreatedWithinDaysUserFilterRule: return CreatedWithinDaysUserFilterRule(within_days=conf[cls._keyword]) def __call__(self, user: User) -> RuleResult: @@ -109,7 +111,7 @@ class TextMatchUserFilterRule(UserFilterRule): pattern: str @classmethod - def parse_from_config(cls, conf: dict) -> "TextMatchUserFilterRule": + def parse_from_config(cls, conf: dict) -> TextMatchUserFilterRule: return cls(pattern=conf[cls._keyword]) def __call__(self, user: User) -> RuleResult: diff --git a/puntgun/rules/user/plan.py b/puntgun/rules/user/plan.py index 773e59e..5935ba5 100644 --- a/puntgun/rules/user/plan.py +++ b/puntgun/rules/user/plan.py @@ -1,4 +1,6 @@ -from typing import ClassVar, List +from __future__ import annotations + +from typing import ClassVar import reactivex as rx from loguru import logger @@ -20,7 +22,7 @@ class UserPlanResult(Recordable): - def __init__(self, plan_id: int, target: User, filtering_result: RuleResult, action_results: List[RuleResult]): + def __init__(self, plan_id: int, target: User, filtering_result: RuleResult, action_results: list[RuleResult]): self.plan_id = plan_id self.target = target self.filtering_result = filtering_result @@ -48,7 +50,7 @@ def to_record(self) -> Record: ) @staticmethod - def parse_from_record(record: Record) -> "UserPlanResult": + def parse_from_record(record: Record) -> UserPlanResult: user: dict = record.data.get("target", {}) filter_rule_record = record.data.get("decisive_filter_rule", {}) action_rule_results: list = record.data.get("action_rule_results", []) @@ -87,7 +89,7 @@ def __call__(self, user: User) -> bool: return True @classmethod - def parse_from_config(cls, conf: dict) -> "UserPlan": + def parse_from_config(cls, conf: dict) -> UserPlan: # we won't directly extract values from configuration and assign them to fields, # so custom validation is needed # as we can't use pydantic library's validating function on fields. diff --git a/puntgun/rules/user/rule_sets.py b/puntgun/rules/user/rule_sets.py index aa36784..6275150 100644 --- a/puntgun/rules/user/rule_sets.py +++ b/puntgun/rules/user/rule_sets.py @@ -6,7 +6,9 @@ so you can make complex cascading execution order tree with them. It's the composite pattern I guess. """ -from typing import Callable, List, Type +from __future__ import annotations + +from typing import Callable import reactivex as rx from loguru import logger @@ -29,10 +31,10 @@ class UserSourceRuleResultMergingSet(UserSourceRule): """ _keyword = "any_of" - rules: List[UserSourceRule] + rules: list[UserSourceRule] @classmethod - def parse_from_config(cls, conf: dict) -> "UserSourceRuleResultMergingSet": + def parse_from_config(cls, conf: dict) -> UserSourceRuleResultMergingSet: return cls(rules=[ConfigParser.parse(c, UserSourceRule) for c in conf["any_of"]]) def __call__(self) -> Observable[User]: @@ -48,11 +50,11 @@ def __call__(self) -> Observable[User]: class UserFilterRuleSet(BaseModel): - immediate_rules: List[UserFilterRule] - slow_rules: List[UserFilterRule] + immediate_rules: list[UserFilterRule] + slow_rules: list[UserFilterRule] @staticmethod - def divide_and_construct(cls: Type["UserFilterRuleSet"], rules: List[UserFilterRule]) -> "UserFilterRuleSet": + def divide_and_construct(cls: type[UserFilterRuleSet], rules: list[UserFilterRule]) -> UserFilterRuleSet: return cls( slow_rules=[r for r in rules if isinstance(r, NeedClientMixin)], immediate_rules=[r for r in rules if not isinstance(r, NeedClientMixin)], @@ -84,7 +86,7 @@ class UserFilterRuleAllOfSet(UserFilterRuleSet, UserFilterRule, NeedClientMixin) _keyword = "all_of" @classmethod - def parse_from_config(cls, conf: dict) -> "UserFilterRuleSet": + def parse_from_config(cls, conf: dict) -> UserFilterRuleSet: return UserFilterRuleSet.divide_and_construct( cls, [ConfigParser.parse(c, UserFilterRule) for c in conf["all_of"]] ) @@ -114,7 +116,7 @@ class UserFilterRuleAnyOfSet(UserFilterRuleSet, UserFilterRule, NeedClientMixin) _keyword = "any_of" @classmethod - def parse_from_config(cls, conf: dict) -> "UserFilterRuleSet": + def parse_from_config(cls, conf: dict) -> UserFilterRuleSet: return UserFilterRuleSet.divide_and_construct( cls, [ConfigParser.parse(c, UserFilterRule) for c in conf["any_of"]] ) @@ -140,13 +142,13 @@ class UserActionRuleResultCollectingSet(UserActionRule): """ _keyword = "all_of" - rules: List[UserActionRule] + rules: list[UserActionRule] @classmethod - def parse_from_config(cls, conf: dict) -> "UserActionRuleResultCollectingSet": + def parse_from_config(cls, conf: dict) -> UserActionRuleResultCollectingSet: return cls(rules=[ConfigParser.parse(c, UserActionRule) for c in conf["all_of"]]) - def __call__(self, user: User) -> Observable[List[RuleResult]]: + def __call__(self, user: User) -> Observable[list[RuleResult]]: action_results = [rx.start(execution_wrapper(user, r)) for r in self.rules] return rx.merge(*action_results).pipe( # collect them into one list diff --git a/puntgun/rules/user/source_rules.py b/puntgun/rules/user/source_rules.py index 5730f73..5e6751a 100644 --- a/puntgun/rules/user/source_rules.py +++ b/puntgun/rules/user/source_rules.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import itertools -from typing import ClassVar, List +from typing import ClassVar import reactivex as rx from loguru import logger @@ -32,7 +34,7 @@ class NameUserSourceRule(UserSourceRule, NeedClientMixin): """ _keyword: ClassVar[str] = "names" - names: List[str] + names: list[str] def __call__(self) -> rx.Observable[User]: return rx.from_iterable(self.names).pipe( @@ -48,7 +50,7 @@ def __call__(self) -> rx.Observable[User]: ) @classmethod - def parse_from_config(cls, conf: dict) -> "NameUserSourceRule": + def parse_from_config(cls, conf: dict) -> NameUserSourceRule: """the config is { 'names': [...] }""" return cls.parse_obj(conf) @@ -61,7 +63,7 @@ class IdUserSourceRule(UserSourceRule, NeedClientMixin): """ _keyword: ClassVar[str] = "ids" - ids: List[int | str] + ids: list[int | str] def __call__(self) -> rx.Observable[User]: return rx.from_iterable(self.ids).pipe( @@ -74,7 +76,7 @@ def __call__(self) -> rx.Observable[User]: ) @classmethod - def parse_from_config(cls, conf: dict) -> "IdUserSourceRule": + def parse_from_config(cls, conf: dict) -> IdUserSourceRule: return cls.parse_obj(conf) @@ -90,7 +92,7 @@ class MyFollowerUserSourceRule(UserSourceRule, NeedClientMixin): after_user: str | None @classmethod - def parse_from_config(cls, conf: dict) -> "FromConfig": + def parse_from_config(cls, conf: dict) -> FromConfig: fields = conf.get(cls._keyword) validate_fields_conflict(fields, [["last"], ["first"], ["after_user"]]) return cls.parse_obj(fields) @@ -104,7 +106,7 @@ def __call__(self) -> rx.Observable[User]: # if no field, return all followers return rx.from_iterable(followers) - def _take_part_of_followers(self, followers: List[User]) -> List[User]: + def _take_part_of_followers(self, followers: list[User]) -> list[User]: if self.last: # the follower API response puts newer followers on list head. return followers[: self.last] diff --git a/puntgun/runner.py b/puntgun/runner.py index 2dab52b..beac1ca 100644 --- a/puntgun/runner.py +++ b/puntgun/runner.py @@ -2,8 +2,10 @@ Plans runner at the highest abstraction level of the tool. Constructing plans, executing plans, collecting and recording plan results... """ +from __future__ import annotations + import sys -from typing import Any, List +from typing import Any import reactivex as rx from loguru import logger @@ -87,9 +89,9 @@ def get_and_validate_plan_config() -> list[dict]: {plan_num} plans found.""" -def parse_plans_config(_plans_config: list[dict]) -> List[Plan]: +def parse_plans_config(_plans_config: list[dict]) -> list[Plan]: """Let the ConfigParser recursively constructing plan instances and rule instances inside plans.""" - plans: List[Plan] = [ConfigParser.parse(p, Plan) for p in _plans_config] + plans: list[Plan] = [ConfigParser.parse(p, Plan) for p in _plans_config] # Can't continue without a zero-error plan configuration if ConfigParser.errors(): @@ -101,7 +103,7 @@ def parse_plans_config(_plans_config: list[dict]) -> List[Plan]: return plans -def execute_plans(plans: List[Plan]) -> None: +def execute_plans(plans: list[Plan]) -> None: def on_error(e: Exception) -> None: logger.error("Error occurred when executing plan", e) raise e diff --git a/tests/rules/test_user_filter_rules.py b/tests/rules/test_user_filter_rules.py index dca8b99..796305d 100644 --- a/tests/rules/test_user_filter_rules.py +++ b/tests/rules/test_user_filter_rules.py @@ -4,8 +4,9 @@ Most functions of these rules are tested when testing delegating agents, so only basic "parsing" logic will be tested in this test module. """ +from __future__ import annotations + import datetime -from typing import List import pytest from dynaconf import Dynaconf @@ -86,7 +87,7 @@ def wrapper(content: str): @pytest.fixture def run_assert(self, config_plan_file_with): def wrapper(regex: str, text: str, expect: bool): - def user_texts(_texts: List[str]): + def user_texts(_texts: list[str]): words = [t if t else "" for t in _texts[0:3]] return User(name=words[0], description=words[1], pinned_tweet_text=words[2]) diff --git a/tests/rules/test_user_rule_sets.py b/tests/rules/test_user_rule_sets.py index 63b89ec..4773e19 100644 --- a/tests/rules/test_user_rule_sets.py +++ b/tests/rules/test_user_rule_sets.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import time -from typing import List import pytest import reactivex as rx @@ -246,7 +247,7 @@ def __call__(self, user: User): class TestUserActionRuleResultCollectingSet: def test_result_aggregating(self): - def action_ruleset_result_checker(results: List[RuleResult]): + def action_ruleset_result_checker(results: list[RuleResult]): """For user filter rule sets testing.""" action_ruleset_result_checker.call_count = 0 diff --git a/tests/test_runner.py b/tests/test_runner.py index 1b0ed7d..101667c 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -1,4 +1,6 @@ -from typing import ClassVar, List +from __future__ import annotations + +from typing import ClassVar import pytest import reactivex as rx @@ -68,7 +70,7 @@ class TRule(FromConfig): class TPlan(Plan): _keyword: ClassVar[str] = "runner_test_plan" - rules: List[FromConfig] + rules: list[FromConfig] def __call__(self) -> Observable[TResult]: return rx.from_iterable([r.f for r in self.rules]).pipe(op.map(lambda i: TResult(i)))