Skip to content

Commit

Permalink
Add "all" option to logging (pytorch#100664)
Browse files Browse the repository at this point in the history
Adds the long-promised "all" option to logging.

Pull Request resolved: pytorch#100664
Approved by: https://github.com/lezcano
  • Loading branch information
mlazos authored and kiersten-stokes committed May 8, 2023
1 parent 9ce5360 commit 006202c
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 11 deletions.
3 changes: 3 additions & 0 deletions docs/source/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ The following components and artifacts are configurable through the ``TORCH_LOGS
variable (see torch._logging.set_logs for the python API):

Components:
``all``
Special component which configures the default log level of all components. Default: ``logging.WARN``

``dynamo``
The log level for the TorchDynamo component. Default: ``logging.WARN``

Expand Down
14 changes: 14 additions & 0 deletions test/dynamo/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,20 @@ def test_open_registration_python_api(self, records):
logger.info("hi")
self.assertEqual(len(records), 1)

@make_logging_test(all=logging.DEBUG, dynamo=logging.INFO)
def test_all(self, _):
registry = torch._logging._internal.log_registry
state = torch._logging._internal.log_state

dynamo_qname = registry.log_alias_to_log_qname["dynamo"]
for logger_qname in torch._logging._internal.log_registry.get_log_qnames():
logger = logging.getLogger(logger_qname)

if logger_qname == dynamo_qname:
self.assertEqual(logger.level, logging.INFO)
else:
self.assertEqual(logger.level, logging.DEBUG)


# single record tests
exclusions = {
Expand Down
62 changes: 51 additions & 11 deletions torch/_logging/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ def clear(self):

def set_logs(
*,
dynamo: int = DEFAULT_LOG_LEVEL,
aot: int = DEFAULT_LOG_LEVEL,
inductor: int = DEFAULT_LOG_LEVEL,
all: Optional[int] = None,
dynamo: Optional[int] = None,
aot: Optional[int] = None,
dynamic: int = None,
inductor: int = None,
bytecode: bool = False,
aot_graphs: bool = False,
aot_joint_graph: bool = False,
Expand Down Expand Up @@ -169,15 +171,21 @@ def set_logs(
is set to a log level less than or equal to the log level of the artifact.
Keyword args:
dynamo (:class:`int`):
all (:class:`Optional[int]`):
The default log level for all components. Default: ``logging.WARN``
dynamo (:class:`Optional[int]`):
The log level for the TorchDynamo component. Default: ``logging.WARN``
aot (:class:`int`):
aot (:class:`Optional[int]`):
The log level for the AOTAutograd component. Default: ``logging.WARN``
inductor (:class:`int`):
inductor (:class:`Optional[int]`):
The log level for the TorchInductor component. Default: ``logging.WARN``
dynamic (:class:`Optional[int]`):
The log level for dynamic shapes. Default: ``logging.WARN``
bytecode (:class:`bool`):
Whether to emit the original and generated bytecode from TorchDynamo.
Default: ``False``
Expand Down Expand Up @@ -250,7 +258,25 @@ def set_logs(
modules = modules or {}

def _set_logs(**kwargs):
default_level = kwargs.pop("all", None)
if default_level:
if default_level not in logging._levelToName:
raise ValueError(
f"Unrecognized log level for kwarg all: {default_level}, valid level values "
f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
)

# add any missing aliases to kwargs
for alias in log_registry.log_alias_to_log_qname.keys():
if alias not in kwargs:
kwargs[alias] = default_level
else:
default_level = DEFAULT_LOG_LEVEL

for alias, val in itertools.chain(kwargs.items(), modules.items()): # type: ignore[union-attr]
if val is None:
val = default_level

if log_registry.is_artifact(alias):
if val:
log_state.enable_artifact(alias)
Expand All @@ -260,10 +286,10 @@ def _set_logs(**kwargs):
f"Unrecognized log level for log {alias}: {val}, valid level values "
f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
)
if val != DEFAULT_LOG_LEVEL:
log_state.enable_log(
log_registry.log_alias_to_log_qname[alias], val
)

log_state.enable_log(log_registry.log_alias_to_log_qname[alias], val)
elif alias == "all":
continue
else:
raise ValueError(
f"Unrecognized log or artifact name passed to set_logs: {alias}"
Expand All @@ -272,9 +298,11 @@ def _set_logs(**kwargs):
_init_logs()

_set_logs(
all=all,
dynamo=dynamo,
aot=aot,
inductor=inductor,
dynamic=dynamic,
bytecode=bytecode,
aot_graphs=aot_graphs,
aot_joint_graph=aot_joint_graph,
Expand Down Expand Up @@ -357,7 +385,9 @@ def _validate_settings(settings):
def _invalid_settings_err_msg(settings):
entities = "\n " + "\n ".join(
itertools.chain(
log_registry.log_alias_to_log_qname.keys(), log_registry.artifact_names
["all"],
log_registry.log_alias_to_log_qname.keys(),
log_registry.artifact_names,
)
)
msg = (
Expand Down Expand Up @@ -392,14 +422,24 @@ def get_name_level_pair(name):
return clean_name, level

log_state = LogState()

for name in log_names:
name, level = get_name_level_pair(name)
if name == "all":
for log_qname in log_registry.get_log_qnames():
log_state.enable_log(log_qname, level)

for name in log_names:
name, level = get_name_level_pair(name)

if log_registry.is_log(name):
assert level is not None
log_qname = log_registry.log_alias_to_log_qname[name]
log_state.enable_log(log_qname, level)
elif log_registry.is_artifact(name):
log_state.enable_artifact(name)
elif name == "all":
continue
elif _is_valid_module(name):
if not _has_registered_parent(name):
log_registry.register_log(name, name)
Expand Down

0 comments on commit 006202c

Please sign in to comment.