@@ -181,6 +181,7 @@ def training_step(self, batch, batch_idx):
181181from pytorch_lightning .utilities .exceptions import MisconfigurationException
182182from pytorch_lightning .utilities .memory import recursive_detach
183183from pytorch_lightning .utilities .parsing import AttributeDict
184+ from pytorch_lightning .utilities .model_utils import is_overridden
184185
185186try :
186187 from apex import amp
@@ -300,10 +301,6 @@ def clip_gradients(self, *args):
300301 def detect_nan_tensors (self , * args ):
301302 """Warning: this is just empty shell for code implemented in other class."""
302303
303- @abstractmethod
304- def is_overridden (self , * args ):
305- """Warning: this is just empty shell for code implemented in other class."""
306-
307304 @abstractmethod
308305 def add_progress_bar_metrics (self , * args ):
309306 """Warning: this is just empty shell for code implemented in other class."""
@@ -572,15 +569,15 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
572569 auto_reduce_tng_result = isinstance (sample_output , Result ) and sample_output .should_reduce_on_epoch_end
573570
574571 # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
575- if self . is_overridden ('training_epoch_end' , model = self .get_model ()) or auto_reduce_tng_result :
572+ if is_overridden ('training_epoch_end' , model = self .get_model ()) or auto_reduce_tng_result :
576573 epoch_end_outputs .append (optimizer_idx_outputs )
577574
578575 return epoch_end_outputs
579576
580577 def check_checkpoint_callback (self , should_check_val ):
581578 # when no val loop is present or fast-dev-run still need to call checkpoints
582579 # TODO bake this logic into the checkpoint callback
583- should_activate = not self . is_overridden ('validation_step' ) and not should_check_val
580+ should_activate = not is_overridden ('validation_step' , self . get_model () ) and not should_check_val
584581 if should_activate :
585582 checkpoint_callbacks = [c for c in self .callbacks if isinstance (c , ModelCheckpoint )]
586583 [c .on_validation_end (self , self .get_model ()) for c in checkpoint_callbacks ]
@@ -642,7 +639,7 @@ def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_sto
642639 # --------------------------
643640 # EPOCH END STEP IF DEFINED
644641 # --------------------------
645- if self . is_overridden ('training_epoch_end' , model = model ):
642+ if is_overridden ('training_epoch_end' , model = model ):
646643 self .global_step += 1
647644
648645 if is_result_obj :
0 commit comments