Skip to content

Commit

Permalink
chore(typing): Add explicit None return type (#1915)
Browse files Browse the repository at this point in the history
* chore(typing): Add explicit None return type

If a function doesn't return anything, specify its
return type as None.

* Apply suggestion by @basepi to make black pass

Co-authored-by: Colton Myers <colton@basepi.net>

---------

Co-authored-by: Colton Myers <colton@basepi.net>
  • Loading branch information
orsinium and basepi committed Oct 9, 2023
1 parent 7af9da9 commit 519a107
Show file tree
Hide file tree
Showing 73 changed files with 270 additions and 268 deletions.
16 changes: 8 additions & 8 deletions elasticapm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class Client(object):

logger = get_logger("elasticapm")

def __init__(self, config=None, **inline):
def __init__(self, config=None, **inline) -> None:
# configure loggers first
cls = self.__class__
self.logger = get_logger("%s.%s" % (cls.__module__, cls.__name__))
Expand Down Expand Up @@ -228,7 +228,7 @@ def __init__(self, config=None, **inline):
# Save this Client object as the global CLIENT_SINGLETON
set_client(self)

def start_threads(self):
def start_threads(self) -> None:
current_pid = os.getpid()
if self._pid != current_pid:
with self._thread_starter_lock:
Expand Down Expand Up @@ -285,7 +285,7 @@ def capture_exception(self, exc_info=None, handled=True, **kwargs):
"""
return self.capture("Exception", exc_info=exc_info, handled=handled, **kwargs)

def queue(self, event_type, data, flush=False):
def queue(self, event_type, data, flush=False) -> None:
if self.config.disable_send:
return
self.start_threads()
Expand Down Expand Up @@ -331,7 +331,7 @@ def end_transaction(self, name=None, result="", duration=None):
transaction = self.tracer.end_transaction(result, name, duration=duration)
return transaction

def close(self):
def close(self) -> None:
if self.config.enabled:
with self._thread_starter_lock:
for _, manager in sorted(self._thread_managers.items(), key=lambda item: item[1].start_stop_order):
Expand Down Expand Up @@ -665,7 +665,7 @@ def _get_stack_info_for_trace(
locals_processor_func=locals_processor_func,
)

def _excepthook(self, type_, value, traceback):
def _excepthook(self, type_, value, traceback) -> None:
try:
self.original_excepthook(type_, value, traceback)
except Exception:
Expand Down Expand Up @@ -701,7 +701,7 @@ def should_ignore_topic(self, topic: str) -> bool:
return True
return False

def check_python_version(self):
def check_python_version(self) -> None:
v = tuple(map(int, platform.python_version_tuple()[:2]))
if v < (3, 6):
warnings.warn("The Elastic APM agent only supports Python 3.6+", DeprecationWarning)
Expand Down Expand Up @@ -736,7 +736,7 @@ def server_version(self):
return self._server_version

@server_version.setter
def server_version(self, new_version):
def server_version(self, new_version) -> None:
if new_version and len(new_version) < 3:
self.logger.debug("APM Server version is too short, padding with zeros")
new_version = new_version + (0,) * (3 - len(new_version))
Expand All @@ -754,7 +754,7 @@ def get_client() -> Client:
return CLIENT_SINGLETON


def set_client(client: Client):
def set_client(client: Client) -> None:
global CLIENT_SINGLETON
if CLIENT_SINGLETON:
logger = get_logger("elasticapm")
Expand Down
56 changes: 28 additions & 28 deletions elasticapm/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@


class ConfigurationError(ValueError):
def __init__(self, msg, field_name):
def __init__(self, msg, field_name) -> None:
self.field_name = field_name
super(ValueError, self).__init__(msg)

Expand Down Expand Up @@ -120,7 +120,7 @@ def __init__(
callbacks_on_default=True,
default=None,
required=False,
):
) -> None:
self.type = type
self.dict_key = dict_key
self.validators = validators
Expand All @@ -138,7 +138,7 @@ def __get__(self, instance, owner):
else:
return self.default

def __set__(self, config_instance, value):
def __set__(self, config_instance, value) -> None:
value = self._validate(config_instance, value)
self._callback_if_changed(config_instance, value)
config_instance._values[self.dict_key] = value
Expand All @@ -159,7 +159,7 @@ def _validate(self, instance, value):
instance._errors.pop(self.dict_key, None)
return value

def _callback_if_changed(self, instance, new_value):
def _callback_if_changed(self, instance, new_value) -> None:
"""
If the value changed (checked against instance._values[self.dict_key]),
then run the callback function (if defined)
Expand All @@ -184,11 +184,11 @@ def call_callbacks(self, old_value, new_value, config_instance):


class _ListConfigValue(_ConfigValue):
def __init__(self, dict_key, list_separator=",", **kwargs):
def __init__(self, dict_key, list_separator=",", **kwargs) -> None:
self.list_separator = list_separator
super(_ListConfigValue, self).__init__(dict_key, **kwargs)

def __set__(self, instance, value):
def __set__(self, instance, value) -> None:
if isinstance(value, str):
value = value.split(self.list_separator)
elif value is not None:
Expand All @@ -200,12 +200,12 @@ def __set__(self, instance, value):


class _DictConfigValue(_ConfigValue):
def __init__(self, dict_key, item_separator=",", keyval_separator="=", **kwargs):
def __init__(self, dict_key, item_separator=",", keyval_separator="=", **kwargs) -> None:
self.item_separator = item_separator
self.keyval_separator = keyval_separator
super(_DictConfigValue, self).__init__(dict_key, **kwargs)

def __set__(self, instance, value):
def __set__(self, instance, value) -> None:
if isinstance(value, str):
items = (item.split(self.keyval_separator) for item in value.split(self.item_separator))
value = {key.strip(): self.type(val.strip()) for key, val in items}
Expand All @@ -217,12 +217,12 @@ def __set__(self, instance, value):


class _BoolConfigValue(_ConfigValue):
def __init__(self, dict_key, true_string="true", false_string="false", **kwargs):
def __init__(self, dict_key, true_string="true", false_string="false", **kwargs) -> None:
self.true_string = true_string
self.false_string = false_string
super(_BoolConfigValue, self).__init__(dict_key, **kwargs)

def __set__(self, instance, value):
def __set__(self, instance, value) -> None:
if isinstance(value, str):
if value.lower() == self.true_string:
value = True
Expand All @@ -240,7 +240,7 @@ class _DurationConfigValue(_ConfigValue):
("m", 60),
)

def __init__(self, dict_key, allow_microseconds=False, unitless_factor=None, **kwargs):
def __init__(self, dict_key, allow_microseconds=False, unitless_factor=None, **kwargs) -> None:
self.type = None # no type coercion
used_units = self.units if allow_microseconds else self.units[1:]
pattern = "|".join(unit[0] for unit in used_units)
Expand All @@ -258,15 +258,15 @@ def __init__(self, dict_key, allow_microseconds=False, unitless_factor=None, **k
validators.insert(0, duration_validator)
super().__init__(dict_key, validators=validators, **kwargs)

def __set__(self, config_instance, value):
def __set__(self, config_instance, value) -> None:
value = self._validate(config_instance, value)
value = timedelta(seconds=float(value))
self._callback_if_changed(config_instance, value)
config_instance._values[self.dict_key] = value


class RegexValidator(object):
def __init__(self, regex, verbose_pattern=None):
def __init__(self, regex, verbose_pattern=None) -> None:
self.regex = regex
self.verbose_pattern = verbose_pattern or regex

Expand All @@ -279,7 +279,7 @@ def __call__(self, value, field_name):


class UnitValidator(object):
def __init__(self, regex, verbose_pattern, unit_multipliers):
def __init__(self, regex, verbose_pattern, unit_multipliers) -> None:
self.regex = regex
self.verbose_pattern = verbose_pattern
self.unit_multipliers = unit_multipliers
Expand Down Expand Up @@ -307,7 +307,7 @@ class PrecisionValidator(object):
begin with), use the minimum instead.
"""

def __init__(self, precision=0, minimum=None):
def __init__(self, precision=0, minimum=None) -> None:
self.precision = precision
self.minimum = minimum

Expand All @@ -329,7 +329,7 @@ def __call__(self, value, field_name):


class ExcludeRangeValidator(object):
def __init__(self, range_start, range_end, range_desc):
def __init__(self, range_start, range_end, range_desc) -> None:
self.range_start = range_start
self.range_end = range_end
self.range_desc = range_desc
Expand Down Expand Up @@ -363,7 +363,7 @@ class EnumerationValidator(object):
of valid string options.
"""

def __init__(self, valid_values, case_sensitive=False):
def __init__(self, valid_values, case_sensitive=False) -> None:
"""
valid_values
List of valid string values for the config value
Expand All @@ -389,7 +389,7 @@ def __call__(self, value, field_name):
return ret


def _log_level_callback(dict_key, old_value, new_value, config_instance):
def _log_level_callback(dict_key, old_value, new_value, config_instance) -> None:
elasticapm_logger = logging.getLogger("elasticapm")
elasticapm_logger.setLevel(log_levels_map.get(new_value, 100))

Expand All @@ -408,7 +408,7 @@ def _log_level_callback(dict_key, old_value, new_value, config_instance):
elasticapm_logger.addHandler(filehandler)


def _log_ecs_reformatting_callback(dict_key, old_value, new_value, config_instance):
def _log_ecs_reformatting_callback(dict_key, old_value, new_value, config_instance) -> None:
"""
If ecs_logging is installed and log_ecs_reformatting is set to "override", we should
set the ecs_logging.StdlibFormatter as the formatted for every handler in
Expand Down Expand Up @@ -439,7 +439,7 @@ def _log_ecs_reformatting_callback(dict_key, old_value, new_value, config_instan
class _ConfigBase(object):
_NO_VALUE = object() # sentinel object

def __init__(self, config_dict=None, env_dict=None, inline_dict=None, copy=False):
def __init__(self, config_dict=None, env_dict=None, inline_dict=None, copy=False) -> None:
"""
config_dict
Configuration dict as is common for frameworks such as flask and django.
Expand Down Expand Up @@ -467,7 +467,7 @@ def __init__(self, config_dict=None, env_dict=None, inline_dict=None, copy=False
if not copy:
self.update(config_dict, env_dict, inline_dict, initial=True)

def update(self, config_dict=None, env_dict=None, inline_dict=None, initial=False):
def update(self, config_dict=None, env_dict=None, inline_dict=None, initial=False) -> None:
if config_dict is None:
config_dict = {}
if env_dict is None:
Expand Down Expand Up @@ -508,7 +508,7 @@ def update(self, config_dict=None, env_dict=None, inline_dict=None, initial=Fals
)
self.call_pending_callbacks()

def call_pending_callbacks(self):
def call_pending_callbacks(self) -> None:
"""
Call callbacks for config options matching list of tuples:
Expand All @@ -523,7 +523,7 @@ def values(self):
return self._values

@values.setter
def values(self, values):
def values(self, values) -> None:
self._values = values

@property
Expand Down Expand Up @@ -717,7 +717,7 @@ class VersionedConfig(ThreadManager):
"start_stop_order",
)

def __init__(self, config_object, version, transport=None):
def __init__(self, config_object, version, transport=None) -> None:
"""
Create a new VersionedConfig with an initial Config object
:param config_object: the initial Config object
Expand Down Expand Up @@ -748,7 +748,7 @@ def update(self, version: str, **config):
else:
return new_config.errors

def reset(self):
def reset(self) -> None:
"""
Reset state to the original configuration
Expand Down Expand Up @@ -776,7 +776,7 @@ def changed(self) -> bool:
def __getattr__(self, item):
return getattr(self._config, item)

def __setattr__(self, name, value):
def __setattr__(self, name, value) -> None:
if name not in self.__slots__:
setattr(self._config, name, value)
else:
Expand Down Expand Up @@ -812,14 +812,14 @@ def update_config(self):

return next_run

def start_thread(self, pid=None):
def start_thread(self, pid=None) -> None:
self._update_thread = IntervalTimer(
self.update_config, 1, "eapm conf updater", daemon=True, evaluate_function_interval=True
)
self._update_thread.start()
super(VersionedConfig, self).start_thread(pid=pid)

def stop_thread(self):
def stop_thread(self) -> None:
if self._update_thread:
self._update_thread.cancel()
self._update_thread = None
Expand Down
4 changes: 2 additions & 2 deletions elasticapm/contrib/aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


class ElasticAPM:
def __init__(self, app, client=None):
def __init__(self, app, client=None) -> None:
if not client:
config = app.get("ELASTIC_APM", {})
config.setdefault("framework_name", "aiohttp")
Expand All @@ -45,7 +45,7 @@ def __init__(self, app, client=None):
self.client = client
self.install_tracing(app, client)

def install_tracing(self, app, client):
def install_tracing(self, app, client) -> None:
from elasticapm.contrib.aiohttp.middleware import tracing_middleware

app.middlewares.insert(0, tracing_middleware(app, client))
Expand Down
4 changes: 2 additions & 2 deletions elasticapm/contrib/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

def wrap_send(send, middleware):
@functools.wraps(send)
async def wrapped_send(message):
async def wrapped_send(message) -> None:
if message.get("type") == "http.response.start":
await set_context(lambda: middleware.get_data_from_response(message, constants.TRANSACTION), "response")
result = "HTTP {}xx".format(message["status"] // 100)
Expand Down Expand Up @@ -218,7 +218,7 @@ async def get_data_from_response(self, message: dict, event_type: str) -> dict:

return result

def set_transaction_name(self, method: str, url: str):
def set_transaction_name(self, method: str, url: str) -> None:
"""
Default implementation sets transaction name to "METHOD unknown route".
Subclasses may add framework specific naming.
Expand Down
4 changes: 2 additions & 2 deletions elasticapm/contrib/asyncio/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ async def __aenter__(self) -> Optional[SpanType]:

async def __aexit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
):
) -> None:
self.handle_exit(exc_type, exc_val, exc_tb)


async def set_context(data, key="custom"):
async def set_context(data, key="custom") -> None:
"""
Asynchronous copy of elasticapm.traces.set_context().
Attach contextual data to the current transaction and errors that happen during the current transaction.
Expand Down
Loading

0 comments on commit 519a107

Please sign in to comment.