Skip to content

Commit 2f6d82e

Browse files
ref: remove on_eval_start hook (Lightning-AI#3176)
* remove on_eval_start hook * remove on_eval_start hook
1 parent 59fb332 commit 2f6d82e

File tree

7 files changed

+25
-53
lines changed

7 files changed

+25
-53
lines changed

pytorch_lightning/callbacks/base.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,18 +137,10 @@ def on_pretrain_routine_end(self, trainer, pl_module):
137137
"""Called when the pretrain routine ends."""
138138
pass
139139

140-
def on_validation_start(self, trainer, pl_module):
141-
"""Called when the validation loop begins."""
142-
pass
143-
144140
def on_validation_end(self, trainer, pl_module):
145141
"""Called when the validation loop ends."""
146142
pass
147143

148-
def on_test_start(self, trainer, pl_module):
149-
"""Called when the test begins."""
150-
pass
151-
152144
def on_test_end(self, trainer, pl_module):
153145
"""Called when the test ends."""
154146
pass

pytorch_lightning/callbacks/progress.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,13 @@ def on_epoch_start(self, trainer, pl_module):
155155
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
156156
self._train_batch_idx += 1
157157

158-
def on_validation_start(self, trainer, pl_module):
158+
def on_validation_epoch_start(self, trainer, pl_module):
159159
self._val_batch_idx = 0
160160

161161
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
162162
self._val_batch_idx += 1
163163

164-
def on_test_start(self, trainer, pl_module):
164+
def on_test_epoch_start(self, trainer, pl_module):
165165
self._test_batch_idx = 0
166166

167167
def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
@@ -338,8 +338,8 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id
338338
self.main_progress_bar.update(self.refresh_rate)
339339
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
340340

341-
def on_validation_start(self, trainer, pl_module):
342-
super().on_validation_start(trainer, pl_module)
341+
def on_validation_epoch_start(self, trainer, pl_module):
342+
super().on_validation_epoch_start(trainer, pl_module)
343343
self.val_progress_bar = self.init_validation_tqdm()
344344
self.val_progress_bar.total = convert_inf(self.total_val_batches)
345345

@@ -358,8 +358,8 @@ def on_train_end(self, trainer, pl_module):
358358
super().on_train_end(trainer, pl_module)
359359
self.main_progress_bar.close()
360360

361-
def on_test_start(self, trainer, pl_module):
362-
super().on_test_start(trainer, pl_module)
361+
def on_test_epoch_start(self, trainer, pl_module):
362+
super().on_test_epoch_start(trainer, pl_module)
363363
self.test_progress_bar = self.init_test_tqdm()
364364
self.test_progress_bar.total = convert_inf(self.total_test_batches)
365365

pytorch_lightning/trainer/callback_hook.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -165,21 +165,11 @@ def on_test_batch_end(self, batch, batch_idx, dataloader_idx):
165165
for callback in self.callbacks:
166166
callback.on_test_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)
167167

168-
def on_validation_start(self):
169-
"""Called when the validation loop begins."""
170-
for callback in self.callbacks:
171-
callback.on_validation_start(self, self.get_model())
172-
173168
def on_validation_end(self):
174169
"""Called when the validation loop ends."""
175170
for callback in self.callbacks:
176171
callback.on_validation_end(self, self.get_model())
177172

178-
def on_test_start(self):
179-
"""Called when the test begins."""
180-
for callback in self.callbacks:
181-
callback.on_test_start(self, self.get_model())
182-
183173
def on_test_end(self):
184174
"""Called when the test ends."""
185175
for callback in self.callbacks:

pytorch_lightning/trainer/evaluate_loop.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,6 @@ def should_skip_evaluation(self, dataloaders, max_batches):
4545

4646
return False
4747

48-
def on_evaluation_start(self, *args, **kwargs):
49-
if self.testing:
50-
self.trainer.call_hook('on_test_start', *args, **kwargs)
51-
else:
52-
self.trainer.call_hook('on_validation_start', *args, **kwargs)
53-
5448
def on_evaluation_end(self, *args, **kwargs):
5549
if self.testing:
5650
self.trainer.call_hook('on_test_end', *args, **kwargs)

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,7 @@ class TrainerEvaluationLoopMixin(ABC):
186186
on_validation_batch_end: Callable
187187
on_test_batch_start: Callable
188188
on_test_batch_end: Callable
189-
on_validation_start: Callable
190189
on_validation_end: Callable
191-
on_test_start: Callable
192190
on_test_end: Callable
193191
accelerator_backend: ...
194192
evaluation_loop: EvaluationLoop
@@ -311,8 +309,10 @@ def run_evaluation(self, test_mode: bool = False):
311309
# set up the loop for val/test
312310
self.evaluation_loop.testing = test_mode
313311

314-
# TODO: deprecate
312+
# enable eval mode + no grads
315313
model = self.get_model()
314+
315+
# TODO: deprecate
316316
model.on_pre_performance_check()
317317

318318
# select dataloaders
@@ -322,13 +322,9 @@ def run_evaluation(self, test_mode: bool = False):
322322
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
323323
return [], []
324324

325-
# TODO: deprecate
326-
self.evaluation_loop.on_evaluation_start()
327-
328325
# ------------------------------
329326
# ------------------------------
330327
# ------------------------------
331-
# enable eval mode + no grads
332328
model.zero_grad()
333329
model.eval()
334330
torch.set_grad_enabled(False)

tests/callbacks/test_callbacks.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ def __init__(self):
3838
self.on_train_end_called = False
3939
self.on_pretrain_routine_start_called = False
4040
self.on_pretrain_routine_end_called = False
41-
self.on_validation_start_called = False
41+
self.on_validation_epoch_start_called = False
4242
self.on_validation_end_called = False
43-
self.on_test_start_called = False
43+
self.on_test_epoch_start_called = False
4444
self.on_test_end_called = False
4545

4646
def setup(self, trainer, pl_module, stage: str):
@@ -131,17 +131,17 @@ def on_pretrain_routine_end(self, trainer, pl_module):
131131
_check_args(trainer, pl_module)
132132
self.on_pretrain_routine_end_called = True
133133

134-
def on_validation_start(self, trainer, pl_module):
134+
def on_validation_epoch_start(self, trainer, pl_module):
135135
_check_args(trainer, pl_module)
136-
self.on_validation_start_called = True
136+
self.on_validation_epoch_start_called = True
137137

138138
def on_validation_end(self, trainer, pl_module):
139139
_check_args(trainer, pl_module)
140140
self.on_validation_end_called = True
141141

142-
def on_test_start(self, trainer, pl_module):
142+
def on_test_epoch_start(self, trainer, pl_module):
143143
_check_args(trainer, pl_module)
144-
self.on_test_start_called = True
144+
self.on_test_epoch_start_called = True
145145

146146
def on_test_end(self, trainer, pl_module):
147147
_check_args(trainer, pl_module)
@@ -180,9 +180,9 @@ def on_test_end(self, trainer, pl_module):
180180
assert not test_callback.on_train_end_called
181181
assert not test_callback.on_pretrain_routine_start_called
182182
assert not test_callback.on_pretrain_routine_end_called
183-
assert not test_callback.on_validation_start_called
183+
assert not test_callback.on_validation_epoch_start_called
184184
assert not test_callback.on_validation_end_called
185-
assert not test_callback.on_test_start_called
185+
assert not test_callback.on_test_epoch_start_called
186186
assert not test_callback.on_test_end_called
187187

188188
# fit model
@@ -211,9 +211,9 @@ def on_test_end(self, trainer, pl_module):
211211
assert not test_callback.on_train_end_called
212212
assert not test_callback.on_pretrain_routine_start_called
213213
assert not test_callback.on_pretrain_routine_end_called
214-
assert not test_callback.on_validation_start_called
214+
assert not test_callback.on_validation_epoch_start_called
215215
assert not test_callback.on_validation_end_called
216-
assert not test_callback.on_test_start_called
216+
assert not test_callback.on_test_epoch_start_called
217217
assert not test_callback.on_test_end_called
218218

219219
trainer.fit(model)
@@ -238,11 +238,11 @@ def on_test_end(self, trainer, pl_module):
238238
assert test_callback.on_train_end_called
239239
assert test_callback.on_pretrain_routine_start_called
240240
assert test_callback.on_pretrain_routine_end_called
241-
assert test_callback.on_validation_start_called
241+
assert test_callback.on_validation_epoch_start_called
242242
assert test_callback.on_validation_end_called
243243
assert not test_callback.on_test_batch_start_called
244244
assert not test_callback.on_test_batch_end_called
245-
assert not test_callback.on_test_start_called
245+
assert not test_callback.on_test_epoch_start_called
246246
assert not test_callback.on_test_end_called
247247

248248
# reset setup teardown callback
@@ -258,9 +258,9 @@ def on_test_end(self, trainer, pl_module):
258258
assert test_callback.teardown_called
259259
assert test_callback.on_test_batch_start_called
260260
assert test_callback.on_test_batch_end_called
261-
assert test_callback.on_test_start_called
261+
assert test_callback.on_test_epoch_start_called
262262
assert test_callback.on_test_end_called
263-
assert not test_callback.on_validation_start_called
263+
assert not test_callback.on_validation_epoch_start_called
264264
assert not test_callback.on_validation_end_called
265265
assert not test_callback.on_validation_batch_end_called
266266
assert not test_callback.on_validation_batch_start_called

tests/trainer/test_dataloaders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,12 +704,12 @@ def on_train_start(self, trainer, pl_module):
704704
assert isinstance(train_sampler, DistributedSampler)
705705
assert train_sampler.shuffle
706706

707-
def on_validation_start(self, trainer, pl_module):
707+
def on_validation_epoch_start(self, trainer, pl_module):
708708
val_sampler = trainer.val_dataloaders[0].sampler
709709
assert isinstance(val_sampler, DistributedSampler)
710710
assert not val_sampler.shuffle
711711

712-
def on_test_start(self, trainer, pl_module):
712+
def on_test_epoch_start(self, trainer, pl_module):
713713
test_sampler = trainer.test_dataloaders[0].sampler
714714
assert isinstance(test_sampler, DistributedSampler)
715715
assert not test_sampler.shuffle

0 commit comments

Comments
 (0)