Commit caf7893

ref: modular is_overridden (Lightning-AI#3290)
* ref: modular is_overridden
* ref: modular is_overridden
* ref: modular is_overridden
* ref: modular is_overridden
1 parent b0f77a7 commit caf7893

8 files changed: +24 -57 lines changed

pytorch_lightning/trainer/callback_config.py

Lines changed: 4 additions & 2 deletions

@@ -18,6 +18,8 @@
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.core.lightning import LightningModule


 class TrainerCallbackConfigMixin(ABC):
@@ -41,13 +43,13 @@ def save_checkpoint(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

     @abstractmethod
-    def is_overridden(self, *args):
+    def get_model(self) -> LightningModule:
         """Warning: this is just empty shell for code implemented in other class."""

     def configure_checkpoint_callback(self, checkpoint_callback):
         if checkpoint_callback is True:
             # when no val step is defined, use 'loss' otherwise 'val_loss'
-            train_step_only = not self.is_overridden('validation_step')
+            train_step_only = not is_overridden('validation_step', self.get_model())
             monitor_key = 'loss' if train_step_only else 'val_loss'
             checkpoint_callback = ModelCheckpoint(
                 filepath=None,

pytorch_lightning/trainer/data_loading.py

Lines changed: 5 additions & 8 deletions

@@ -26,6 +26,7 @@
 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.debugging import InternalDebugger
+from pytorch_lightning.utilities.model_utils import is_overridden


 try:
@@ -78,10 +79,6 @@ class TrainerDataLoadingMixin(ABC):
     distributed_backend: Optional[str]
     dev_debugger: InternalDebugger

-    @abstractmethod
-    def is_overridden(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         on_windows = platform.system() == 'Windows'

@@ -305,8 +302,8 @@ def reset_val_dataloader(self, model: LightningModule) -> None:
         Args:
             model: The current `LightningModule`
         """
-        has_loader = self.is_overridden('val_dataloader', model)
-        has_step = self.is_overridden('validation_step', model)
+        has_loader = is_overridden('val_dataloader', model)
+        has_step = is_overridden('validation_step', model)
         if has_loader and has_step:
             self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')

@@ -316,8 +313,8 @@ def reset_test_dataloader(self, model) -> None:
         Args:
             model: The current `LightningModule`
         """
-        has_loader = self.is_overridden('test_dataloader', model)
-        has_step = self.is_overridden('test_step', model)
+        has_loader = is_overridden('test_dataloader', model)
+        has_step = is_overridden('test_step', model)
         if has_loader and has_step:
             self.num_test_batches, self.test_dataloaders = \
                 self._reset_eval_dataloader(model, 'test')

pytorch_lightning/trainer/evaluate_loop.py

Lines changed: 3 additions & 2 deletions

@@ -3,6 +3,7 @@
 from pytorch_lightning.core.step_result import Result, EvalResult
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities import flatten_dict
+from pytorch_lightning.utilities.model_utils import is_overridden


 class EvaluationLoop(object):
@@ -179,15 +180,15 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
         user_reduced = False

         if self.testing:
-            if self.trainer.is_overridden('test_epoch_end', model=model):
+            if is_overridden('test_epoch_end', model=model):
                 if using_eval_result:
                     eval_results = self.__gather_epoch_end_eval_results(outputs)

                 eval_results = model.test_epoch_end(eval_results)
                 user_reduced = True

         else:
-            if self.trainer.is_overridden('validation_epoch_end', model=model):
+            if is_overridden('validation_epoch_end', model=model):
                 if using_eval_result:
                     eval_results = self.__gather_epoch_end_eval_results(outputs)

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 0 additions & 4 deletions

@@ -201,10 +201,6 @@ def copy_trainer_model_properties(self, *args):
     def get_model(self) -> LightningModule:
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def is_overridden(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def transfer_batch_to_gpu(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

pytorch_lightning/trainer/model_hooks.py

Lines changed: 0 additions & 28 deletions

@@ -26,34 +26,6 @@ def is_function_implemented(self, f_name, model=None):
         f_op = getattr(model, f_name, None)
         return callable(f_op)

-    def is_overridden(self, method_name: str, model: LightningModule = None) -> bool:
-        if model is None:
-            model = self.get_model()
-        # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super
-        # TODO - refector this function to accept model_name, instance, parent so it makes more sense
-        super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule
-
-        # assert model, 'no model passes'
-
-        if not hasattr(model, method_name):
-            # in case of calling deprecated method
-            return False
-
-        instance_attr = getattr(model, method_name)
-        if not instance_attr:
-            return False
-        super_attr = getattr(super_object, method_name)
-
-        # when code pointers are different, it was implemented
-        if hasattr(instance_attr, 'patch_loader_code'):
-            # cannot pickle __code__ so cannot verify if PatchDataloader
-            # exists which shows dataloader methods have been overwritten.
-            # so, we hack it by using the string representation
-            is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__)
-        else:
-            is_overridden = instance_attr.__code__ is not super_attr.__code__
-        return is_overridden
-
     def has_arg(self, f_name, arg_name):
         model = self.get_model()
         f_op = getattr(model, f_name, None)
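
Note: the new pytorch_lightning/utilities/model_utils.py module itself is not among the eight files shown in this diff. Based on the mixin method removed above and the updated call sites, the standalone utility presumably looks roughly like the sketch below; treat the exact contents as an assumption, reconstructed rather than copied from the commit.

# Hedged reconstruction of pytorch_lightning/utilities/model_utils.py.
# The body mirrors the deleted mixin method, minus `self` and the implicit
# `get_model()` fallback: callers now pass the instance explicitly.
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.datamodule import LightningDataModule


def is_overridden(method_name: str, model) -> bool:
    # a LightningDataModule is compared against LightningDataModule,
    # anything else against LightningModule
    super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule

    if not hasattr(model, method_name):
        # in case of calling a deprecated method
        return False

    instance_attr = getattr(model, method_name)
    if not instance_attr:
        return False
    super_attr = getattr(super_object, method_name)

    # when the code pointers differ, the user implemented the method
    if hasattr(instance_attr, 'patch_loader_code'):
        # __code__ cannot be pickled, so dataloader methods patched by
        # PatchDataloader are compared via their string representation
        return instance_attr.patch_loader_code != str(super_attr.__code__)
    return instance_attr.__code__ is not super_attr.__code__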

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 5 deletions

@@ -56,6 +56,7 @@
 from pytorch_lightning.utilities.cloud_io import is_remote_path
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
 from pytorch_lightning.trainer.data_connector import DataConnector
+from pytorch_lightning.utilities.model_utils import is_overridden

 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -902,7 +903,7 @@ def disable_validation(self) -> bool:
     @property
     def enable_validation(self) -> bool:
         """ Check if we should run validation during training. """
-        val_loop_enabled = self.is_overridden('validation_step') and self.limit_val_batches > 0
+        val_loop_enabled = is_overridden('validation_step', self.get_model()) and self.limit_val_batches > 0
         return val_loop_enabled or self.fast_dev_run

     @property
@@ -1091,7 +1092,7 @@ def select_accelerator(self):

     def can_prepare_data(self):
         should_call_dm_prepare_data = True
-        if self.datamodule is not None and self.is_overridden('prepare_data', self.datamodule):
+        if self.datamodule is not None and is_overridden('prepare_data', self.datamodule):
             should_call_dm_prepare_data = not self.datamodule.has_prepared_data

         if self.prepare_data_per_node:
@@ -1197,7 +1198,7 @@ def run_pretrain_routine(self, model: LightningModule):
         self.train()

     def _run_sanity_check(self, ref_model, model):
-        using_val_step = ref_model.val_dataloader is not None and self.is_overridden('validation_step')
+        using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', self.get_model())
         should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

         # run tiny validation (if validation defined)
@@ -1418,8 +1419,8 @@ def call_hook(self, hook_name, *args, **kwargs):

         # next call hook in lightningModule
         output = None
-        if self.is_overridden(hook_name):
-            model_ref = self.get_model()
+        model_ref = self.get_model()
+        if is_overridden(hook_name, model_ref):
             hook_fx = getattr(model_ref, hook_name)
             output = hook_fx(*args, **kwargs)

pytorch_lightning/trainer/training_loop.py

Lines changed: 4 additions & 7 deletions

@@ -181,6 +181,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.parsing import AttributeDict
+from pytorch_lightning.utilities.model_utils import is_overridden

 try:
     from apex import amp
@@ -300,10 +301,6 @@ def clip_gradients(self, *args):
     def detect_nan_tensors(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def is_overridden(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def add_progress_bar_metrics(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -572,15 +569,15 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
             auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end

             # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
-            if self.is_overridden('training_epoch_end', model=self.get_model()) or auto_reduce_tng_result:
+            if is_overridden('training_epoch_end', model=self.get_model()) or auto_reduce_tng_result:
                 epoch_end_outputs.append(optimizer_idx_outputs)

         return epoch_end_outputs

     def check_checkpoint_callback(self, should_check_val):
         # when no val loop is present or fast-dev-run still need to call checkpoints
         # TODO bake this logic into the checkpoint callback
-        should_activate = not self.is_overridden('validation_step') and not should_check_val
+        should_activate = not is_overridden('validation_step', self.get_model()) and not should_check_val
         if should_activate:
             checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
             [c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks]
@@ -642,7 +639,7 @@ def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_sto
         # --------------------------
         # EPOCH END STEP IF DEFINED
         # --------------------------
-        if self.is_overridden('training_epoch_end', model=model):
+        if is_overridden('training_epoch_end', model=model):
             self.global_step += 1

             if is_result_obj:

tests/core/test_datamodules.py

Lines changed: 2 additions & 1 deletion

@@ -9,6 +9,7 @@
 from tests.base import EvalModelTemplate
 from tests.base.datamodules import TrialMNISTDataModule
 from tests.base.develop_utils import reset_seed
+from pytorch_lightning.utilities.model_utils import is_overridden


 def test_can_prepare_data(tmpdir):
@@ -348,7 +349,7 @@ def transfer_batch_to_device(self, data, device):
     trainer = Trainer()
     # running .fit() would require us to implement custom data loaders, we mock the model reference instead
     trainer.get_model = MagicMock(return_value=model)
-    if trainer.is_overridden('transfer_batch_to_device', dm):
+    if is_overridden('transfer_batch_to_device', dm):
         model.transfer_batch_to_device = dm.transfer_batch_to_device

     batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
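
As a closing illustration, a minimal sketch of how the relocated utility behaves for a plain LightningModule, assuming the reconstruction above (PlainModel is a hypothetical stand-in, not from this commit or its tests):

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.model_utils import is_overridden


class PlainModel(LightningModule):
    def training_step(self, batch, batch_idx):
        pass  # user-implemented hook


model = PlainModel()
assert is_overridden('training_step', model)        # overridden by the user
assert not is_overridden('validation_step', model)  # only the base-class default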
