Skip to content

Commit

Permalink
update usage of deprecated profiler (Lightning-AI#5010)
Browse files Browse the repository at this point in the history
* drop deprecated profiler

* lut

Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
  • Loading branch information
Borda and s-rog committed Dec 10, 2020
1 parent cdbddbe commit 77fb425
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pl_examples/domain_templates/imagenet.py
Expand Up @@ -237,7 +237,7 @@ def run_cli():
help='seed for initializing training.')
parser = ImageNetLightningModel.add_model_specific_args(parent_parser)
parser.set_defaults(
profiler=True,
profiler="simple",
deterministic=True,
max_epochs=90,
)
Expand Down
19 changes: 11 additions & 8 deletions pytorch_lightning/trainer/connectors/profiler_connector.py
Expand Up @@ -18,6 +18,11 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

PROFILERS = {
"simple": SimpleProfiler,
"advanced": AdvancedProfiler,
}


class ProfilerConnector:

Expand All @@ -28,9 +33,9 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):

if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
# TODO: Update exception on removal of bool
raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` "
"are valid values for `Trainer`'s `profiler` parameter. "
f"Received {profiler} which is of type {type(profiler)}.")
raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler`"
" are valid values for `Trainer`'s `profiler` parameter."
f" Received {profiler} which is of type {type(profiler)}.")

if isinstance(profiler, bool):
rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
Expand All @@ -39,11 +44,9 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
if profiler:
profiler = SimpleProfiler()
elif isinstance(profiler, str):
profiler = profiler.lower()
if profiler == "simple":
profiler = SimpleProfiler()
elif profiler == "advanced":
profiler = AdvancedProfiler()
if profiler.lower() in PROFILERS:
profiler_class = PROFILERS[profiler.lower()]
profiler = profiler_class()
else:
raise ValueError("When passing string value for the `profiler` parameter of"
" `Trainer`, it can only be 'simple' or 'advanced'")
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer.py
Expand Up @@ -1476,6 +1476,6 @@ def test_trainer_profiler_incorrect_str_arg():
))
def test_trainer_profiler_incorrect_arg_type(profiler):
with pytest.raises(MisconfigurationException,
match=r"Only None, bool, str and subclasses of `BaseProfiler` "
r"are valid values for `Trainer`'s `profiler` parameter. *"):
match=r"Only None, bool, str and subclasses of `BaseProfiler`"
r" are valid values for `Trainer`'s `profiler` parameter. *"):
Trainer(profiler=profiler)

0 comments on commit 77fb425

Please sign in to comment.