Skip to content

Commit

Permalink
update usage of deprecated checkpoint_callback (Lightning-AI#5006)
Browse files Browse the repository at this point in the history
* drop usage of deprecated checkpoint_callback

* fix

* fix
  • Loading branch information
Borda committed Dec 9, 2020
1 parent ce91795 commit 05f25f3
Show file tree
Hide file tree
Showing 12 changed files with 37 additions and 29 deletions.
7 changes: 6 additions & 1 deletion tests/backends/test_tpu_backend.py
Expand Up @@ -39,7 +39,12 @@ def test_resume_training_on_cpu(tmpdir):
assert weight_tensor.device == torch.device("cpu")

# Verify that training is resumed on CPU
trainer = Trainer(resume_from_checkpoint=model_path, checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir)
trainer = Trainer(
resume_from_checkpoint=model_path,
checkpoint_callback=True,
max_epochs=1,
default_root_dir=tmpdir,
)
result = trainer.fit(model)

assert result == 1
Expand Down
3 changes: 1 addition & 2 deletions tests/callbacks/test_early_stopping.py
Expand Up @@ -56,8 +56,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
early_stop_callback = EarlyStoppingTestRestore()
trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=checkpoint_callback,
callbacks=[early_stop_callback],
callbacks=[early_stop_callback, checkpoint_callback],
num_sanity_val_steps=0,
max_epochs=4,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/checkpointing/test_checkpoint_callback_frequency.py
Expand Up @@ -129,7 +129,7 @@ def training_step(self, batch, batch_idx):

model = TestModel()
trainer = Trainer(
checkpoint_callback=callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss', save_top_k=k),
callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss', save_top_k=k)],
default_root_dir=tmpdir,
max_epochs=epochs,
weights_summary=None,
Expand Down
2 changes: 1 addition & 1 deletion tests/checkpointing/test_model_checkpoint.py
Expand Up @@ -897,7 +897,7 @@ def test_configure_model_checkpoint(tmpdir):
assert trainer.checkpoint_callbacks == [callback1, callback2]

with pytest.warns(DeprecationWarning, match='will no longer be supported in v1.3'):
trainer = Trainer(checkpoint_callback=callback1, callbacks=[], **kwargs)
trainer = Trainer(checkpoint_callback=callback1, **kwargs)
assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
assert trainer.checkpoint_callback == callback1

Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_datamodules.py
Expand Up @@ -243,7 +243,7 @@ def test_dm_checkpoint_save(tmpdir):
default_root_dir=tmpdir,
max_epochs=3,
weights_summary=None,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on')
callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on')],
)

# fit model
Expand Down
2 changes: 1 addition & 1 deletion tests/models/data/horovod/train_default_model.py
Expand Up @@ -49,7 +49,7 @@ def run_test_from_config(trainer_options):
reset_seed()

ckpt_path = trainer_options['weights_save_path']
trainer_options.update(checkpoint_callback=ModelCheckpoint(dirpath=ckpt_path))
trainer_options.update(callbacks=[ModelCheckpoint(dirpath=ckpt_path)])

model = EvalModelTemplate()

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_amp.py
Expand Up @@ -129,7 +129,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
gpus=[0],
accelerator='ddp_spawn',
precision=16,
checkpoint_callback=checkpoint,
callbacks=[checkpoint],
logger=logger,
)
trainer.is_slurm_managing_tasks = True
Expand Down
8 changes: 4 additions & 4 deletions tests/models/test_cpu.py
Expand Up @@ -43,7 +43,7 @@ def test_cpu_slurm_save_load(enable_pl_optimizer, tmpdir):
logger=logger,
limit_train_batches=0.2,
limit_val_batches=0.2,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
enable_pl_optimizer=enable_pl_optimizer,
)
result = trainer.fit(model)
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_cpu_slurm_save_load(enable_pl_optimizer, tmpdir):
default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
enable_pl_optimizer=enable_pl_optimizer,
)
model = EvalModelTemplate(**hparams)
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_running_test_after_fitting(tmpdir):
limit_train_batches=0.4,
limit_val_batches=0.2,
limit_test_batches=0.2,
checkpoint_callback=checkpoint,
callbacks=[checkpoint],
logger=logger,
)
result = trainer.fit(model)
Expand Down Expand Up @@ -239,7 +239,7 @@ def test_running_test_no_val(tmpdir):
limit_train_batches=0.4,
limit_val_batches=0.2,
limit_test_batches=0.2,
checkpoint_callback=checkpoint,
callbacks=[checkpoint],
logger=logger,
)
result = trainer.fit(model)
Expand Down
18 changes: 10 additions & 8 deletions tests/models/test_restore.py
Expand Up @@ -159,7 +159,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
max_epochs=2,
limit_train_batches=0.4,
limit_val_batches=0.2,
checkpoint_callback=checkpoint,
callbacks=[checkpoint],
logger=logger,
gpus=[0, 1],
accelerator='dp',
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
max_epochs=2,
limit_train_batches=0.4,
limit_val_batches=0.2,
checkpoint_callback=checkpoint,
callbacks=[checkpoint],
logger=logger,
gpus=[0, 1],
accelerator='ddp_spawn',
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
max_epochs=3,
limit_train_batches=0.4,
limit_val_batches=0.2,
checkpoint_callback=checkpoint,
callbacks=[checkpoint],
logger=logger,
default_root_dir=tmpdir,
)
Expand Down Expand Up @@ -288,7 +288,7 @@ def test_load_model_from_checkpoint(tmpdir, model_template):
max_epochs=2,
limit_train_batches=0.4,
limit_val_batches=0.2,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_top_k=-1),
callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_top_k=-1)],
default_root_dir=tmpdir,
)

Expand Down Expand Up @@ -404,8 +404,10 @@ def test_model_saving_loading(tmpdir):

# fit model
trainer = Trainer(
max_epochs=1, logger=logger,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir), default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
default_root_dir=tmpdir,
)
result = trainer.fit(model)

Expand Down Expand Up @@ -460,7 +462,7 @@ def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server, url_c
# fit model
trainer = Trainer(
default_root_dir=tmpdir, max_epochs=1, logger=logger,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
)
result = trainer.fit(model)

Expand Down Expand Up @@ -500,7 +502,7 @@ def test_strict_model_load_less_params(monkeypatch, tmpdir, tmpdir_server, url_c
# fit model
trainer = Trainer(
default_root_dir=tmpdir, max_epochs=1, logger=logger,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
)
result = trainer.fit(model)

Expand Down
5 changes: 3 additions & 2 deletions tests/trainer/logging_tests/test_eval_loop_logging_1_0.py
Expand Up @@ -26,6 +26,7 @@
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import Trainer, callbacks, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import BoringModel, RandomDataset, SimpleModule
Expand Down Expand Up @@ -291,7 +292,7 @@ def validation_epoch_end(self, outputs) -> None:
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
checkpoint_callback=callbacks.ModelCheckpoint(dirpath='val_loss')
callbacks=[ModelCheckpoint(dirpath='val_loss')],
)
trainer.fit(model)

Expand Down Expand Up @@ -358,7 +359,7 @@ def test_monitor_val_epoch_end(tmpdir):
trainer = Trainer(
max_epochs=epoch_min_loss_override + 2,
logger=False,
checkpoint_callback=checkpoint_callback,
callbacks=[checkpoint_callback],
)
trainer.fit(model)

Expand Down
3 changes: 2 additions & 1 deletion tests/trainer/logging_tests/test_train_loop_logging_1_0.py
Expand Up @@ -27,6 +27,7 @@

import pytorch_lightning as pl
from pytorch_lightning import Trainer, callbacks
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
from tests.base.deterministic_model import DeterministicModel
Expand Down Expand Up @@ -88,7 +89,7 @@ def backward(self, loss, optimizer, optimizer_idx):
max_epochs=2,
log_every_n_steps=1,
weights_summary=None,
checkpoint_callback=callbacks.ModelCheckpoint(monitor='l_se')
callbacks=[ModelCheckpoint(monitor='l_se')],
)
trainer.fit(model)

Expand Down
12 changes: 6 additions & 6 deletions tests/trainer/test_trainer.py
Expand Up @@ -55,7 +55,7 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
)
# fit model
result = trainer.fit(model)
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
)
result = trainer.fit(model)

Expand Down Expand Up @@ -145,7 +145,7 @@ def test_strict_model_load(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir),
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
)
result = trainer.fit(model)

Expand Down Expand Up @@ -462,7 +462,7 @@ def test_model_checkpoint_only_weights(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_weights_only=True),
callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_weights_only=True)],
)
# fit model
result = trainer.fit(model)
Expand Down Expand Up @@ -539,7 +539,7 @@ def increment_on_load_checkpoint(self, _):
max_epochs=2,
limit_train_batches=0.65,
limit_val_batches=1,
checkpoint_callback=ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_top_k=-1),
callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor='early_stop_on', save_top_k=-1)],
default_root_dir=tmpdir,
val_check_interval=1.0,
enable_pl_optimizer=enable_pl_optimizer,
Expand Down Expand Up @@ -718,7 +718,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
max_epochs=2,
progress_bar_refresh_rate=0,
default_root_dir=tmpdir,
checkpoint_callback=ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k),
callbacks=[ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k)],
)
trainer.fit(model)
if ckpt_path == "best":
Expand Down

0 comments on commit 05f25f3

Please sign in to comment.