From b50ad9ee95b0c08b644f43e6c9bc8cf34ac8f324 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 12 Dec 2020 15:55:11 +0100 Subject: [PATCH] split tests for deprecated api (#5071) * imports * imports * flake8 Co-authored-by: Rohit Gupta --- tests/deprecated_api/__init__.py | 21 ++++ tests/deprecated_api/test_remove_1-2.py | 45 +++++++++ .../test_remove_1-3.py} | 95 +++---------------- 3 files changed, 78 insertions(+), 83 deletions(-) create mode 100644 tests/deprecated_api/__init__.py create mode 100644 tests/deprecated_api/test_remove_1-2.py rename tests/{test_deprecated.py => deprecated_api/test_remove_1-3.py} (60%) diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py new file mode 100644 index 0000000000000..99e21d1ed6b22 --- /dev/null +++ b/tests/deprecated_api/__init__.py @@ -0,0 +1,21 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in vX.Y.Z""" +import sys + + +def _soft_unimport_module(str_module): + # once the module is imported e.g with parsing with pytest it lives in memory + if str_module in sys.modules: + del sys.modules[str_module] diff --git a/tests/deprecated_api/test_remove_1-2.py b/tests/deprecated_api/test_remove_1-2.py new file mode 100644 index 0000000000000..331208d56df10 --- /dev/null +++ b/tests/deprecated_api/test_remove_1-2.py @@ -0,0 +1,45 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in vX.Y.Z""" + +import pytest +import torch + +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +def test_tbd_remove_in_v1_2_0(): + with pytest.deprecated_call(match='will be removed in v1.2'): + ModelCheckpoint(filepath='..') + + with pytest.deprecated_call(match='will be removed in v1.2'): + ModelCheckpoint('..') + + with pytest.raises(MisconfigurationException, match='inputs which are not feasible'): + ModelCheckpoint(filepath='..', dirpath='.') + + +def test_tbd_remove_in_v1_2_0_metrics(): + from pytorch_lightning.metrics.classification import Fbeta + from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score + + with pytest.deprecated_call(match='will be removed in v1.2'): + Fbeta(2) + + with pytest.deprecated_call(match='will be removed in v1.2'): + fbeta_score(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 1]), 0.2) + + with pytest.deprecated_call(match='will be removed in v1.2'): + f1_score(torch.tensor([0, 1, 0, 1]), torch.tensor([0, 1, 0, 0])) diff --git a/tests/test_deprecated.py b/tests/deprecated_api/test_remove_1-3.py similarity index 60% rename from tests/test_deprecated.py rename to tests/deprecated_api/test_remove_1-3.py index 59c6728009b6f..7ec69796b1e46 100644 --- a/tests/test_deprecated.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test deprecated functionality which will be removed in vX.Y.Z""" -import sys from argparse import ArgumentParser from unittest import mock @@ -21,10 +20,8 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.metrics.functional.classification import auc from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import EvalModelTemplate def test_tbd_remove_in_v1_3_0(tmpdir): @@ -52,27 +49,27 @@ def __init__(self, hparams): def test_tbd_remove_in_v1_3_0_metrics(): + from pytorch_lightning.metrics.functional.classification import to_onehot with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.classification import to_onehot to_onehot(torch.tensor([1, 2, 3])) + from pytorch_lightning.metrics.functional.classification import to_categorical with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.classification import to_categorical to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]])) + from pytorch_lightning.metrics.functional.classification import get_num_classes with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.classification import get_num_classes get_num_classes(pred=torch.tensor([0, 1]), target=torch.tensor([1, 1])) x_binary = torch.tensor([0, 1, 2, 3]) y_binary = torch.tensor([0, 1, 2, 3]) + from pytorch_lightning.metrics.functional.classification import roc with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.classification import roc roc(pred=x_binary, target=y_binary) + from pytorch_lightning.metrics.functional.classification import _roc with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.classification import _roc _roc(pred=x_binary, target=y_binary) x_multy = torch.tensor([[0.85, 0.05, 0.05, 0.05], @@ -81,64 +78,40 @@ def test_tbd_remove_in_v1_3_0_metrics(): [0.05, 0.05, 0.05, 0.85]]) y_multy = torch.tensor([0, 1, 3, 2]) + from pytorch_lightning.metrics.functional.classification import multiclass_roc with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.classification import multiclass_roc multiclass_roc(pred=x_multy, target=y_multy) + from pytorch_lightning.metrics.functional.classification import average_precision with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.classification import average_precision average_precision(pred=x_binary, target=y_binary) + from pytorch_lightning.metrics.functional.classification import precision_recall_curve with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.classification import precision_recall_curve precision_recall_curve(pred=x_binary, target=y_binary) + from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve multiclass_precision_recall_curve(pred=x_multy, target=y_multy) + from pytorch_lightning.metrics.functional.reduction import reduce with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.reduction import reduce reduce(torch.tensor([0, 1, 1, 0]), 'sum') + from pytorch_lightning.metrics.functional.reduction import class_reduce with pytest.deprecated_call(match='will be removed in v1.3'): - from pytorch_lightning.metrics.functional.reduction import class_reduce class_reduce(torch.randint(1, 10, (50,)).float(), torch.randint(10, 20, (50,)).float(), torch.randint(1, 100, (50,)).float()) -def test_tbd_remove_in_v1_2_0(): - with pytest.deprecated_call(match='will be removed in v1.2'): - checkpoint_cb = ModelCheckpoint(filepath='.') - - with pytest.deprecated_call(match='will be removed in v1.2'): - checkpoint_cb = ModelCheckpoint('.') - - with pytest.raises(MisconfigurationException, match='inputs which are not feasible'): - checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.') - - -def test_tbd_remove_in_v1_2_0_metrics(): - from pytorch_lightning.metrics.classification import Fbeta - from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score - - with pytest.deprecated_call(match='will be removed in v1.2'): - Fbeta(2) - - with pytest.deprecated_call(match='will be removed in v1.2'): - fbeta_score(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 1]), 0.2) - - with pytest.deprecated_call(match='will be removed in v1.2'): - f1_score(torch.tensor([0, 1, 0, 1]), torch.tensor([0, 1, 0, 0])) - - # TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py @pytest.mark.parametrize(['profiler', 'expected'], [ (True, SimpleProfiler), (False, PassThroughProfiler), ]) def test_trainer_profiler_remove_in_v1_3_0(profiler, expected): + # remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py with pytest.deprecated_call(match='will be removed in v1.3'): trainer = Trainer(profiler=profiler) assert isinstance(trainer.profiler, expected) @@ -162,47 +135,3 @@ def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, ex assert getattr(args, "profiler") == expected_parsed_arg trainer = Trainer.from_argparse_args(args) assert isinstance(trainer.profiler, expected_profiler) - - -def _soft_unimport_module(str_module): - # once the module is imported e.g with parsing with pytest it lives in memory - if str_module in sys.modules: - del sys.modules[str_module] - - -class ModelVer0_6(EvalModelTemplate): - - # todo: this shall not be needed while evaluate asks for dataloader explicitly - def val_dataloader(self): - return self.dataloader(train=False) - - def validation_step(self, batch, batch_idx, *args, **kwargs): - return {'val_loss': torch.tensor(0.6)} - - def validation_end(self, outputs): - return {'val_loss': torch.tensor(0.6)} - - def test_dataloader(self): - return self.dataloader(train=False) - - def test_end(self, outputs): - return {'test_loss': torch.tensor(0.6)} - - -class ModelVer0_7(EvalModelTemplate): - - # todo: this shall not be needed while evaluate asks for dataloader explicitly - def val_dataloader(self): - return self.dataloader(train=False) - - def validation_step(self, batch, batch_idx, *args, **kwargs): - return {'val_loss': torch.tensor(0.7)} - - def validation_end(self, outputs): - return {'val_loss': torch.tensor(0.7)} - - def test_dataloader(self): - return self.dataloader(train=False) - - def test_end(self, outputs): - return {'test_loss': torch.tensor(0.7)}