|
36 | 36 | from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn |
37 | 37 | from pytorch_lightning.utilities.debugging import InternalDebugger |
38 | 38 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
| 39 | +from pytorch_lightning.trainer.configuration_validator import ConfigValidator |
39 | 40 |
|
40 | 41 | # warnings to ignore in trainer |
41 | 42 | warnings.filterwarnings( |
@@ -644,6 +645,7 @@ def __init__( |
644 | 645 |
|
645 | 646 | # tracks internal state for debugging |
646 | 647 | self.dev_debugger = InternalDebugger(self) |
| 648 | + self.config_validator = ConfigValidator(self) |
647 | 649 |
|
648 | 650 | # Callback system |
649 | 651 | self.on_init_end() |
@@ -974,18 +976,19 @@ def fit( |
974 | 976 | if hasattr(model, 'hparams'): |
975 | 977 | parsing.clean_namespace(model.hparams) |
976 | 978 |
|
977 | | - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders |
978 | | - if (train_dataloader or val_dataloaders) and datamodule: |
979 | | - raise MisconfigurationException( |
980 | | - 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' |
981 | | - ) |
| 979 | + # if a datamodule comes in as the second arg, then fix it for the user |
| 980 | + if isinstance(train_dataloader, LightningDataModule): |
| 981 | + datamodule = train_dataloader |
| 982 | + train_dataloader = None |
| 983 | + |
| 984 | + self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule) |
982 | 985 |
|
983 | 986 | # set up the passed in dataloaders (if needed) |
984 | 987 | self.__attach_dataloaders(model, train_dataloader, val_dataloaders) |
985 | 988 | self.__attach_datamodule(model, datamodule) |
986 | 989 |
|
987 | 990 | # check that model is configured correctly |
988 | | - self.check_model_configuration(model) |
| 991 | + self.config_validator.verify_loop_configurations(model) |
989 | 992 |
|
990 | 993 | # callbacks |
991 | 994 | self.on_fit_start() |
@@ -1256,9 +1259,9 @@ def run_pretrain_routine(self, model: LightningModule): |
1256 | 1259 | self.train() |
1257 | 1260 |
|
1258 | 1261 | def _run_sanity_check(self, ref_model, model): |
1259 | | - should_sanity_check = ( |
1260 | | - self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 |
1261 | | - ) |
| 1262 | + |
| 1263 | + using_val_step = ref_model.val_dataloader is not None and self.is_overridden('validation_step') |
| 1264 | + should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 |
1262 | 1265 |
|
1263 | 1266 | # run tiny validation (if validation defined) |
1264 | 1267 | # to make sure program won't crash during val |
@@ -1448,73 +1451,6 @@ def __test_given_model(self, model, test_dataloaders): |
1448 | 1451 |
|
1449 | 1452 | return results |
1450 | 1453 |
|
1451 | | - def check_model_configuration(self, model: LightningModule): |
1452 | | - r""" |
1453 | | - Checks that the model is configured correctly before training or testing is started. |
1454 | | -
|
1455 | | - Args: |
1456 | | - model: The model to check the configuration. |
1457 | | -
|
1458 | | - """ |
1459 | | - # Check training_step, train_dataloader, configure_optimizer methods |
1460 | | - if not self.testing: |
1461 | | - if not self.is_overridden('training_step', model): |
1462 | | - raise MisconfigurationException( |
1463 | | - 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a' |
1464 | | - ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' |
1465 | | - ) |
1466 | | - |
1467 | | - if not self.is_overridden('train_dataloader', model): |
1468 | | - raise MisconfigurationException( |
1469 | | - 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' |
1470 | | - ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' |
1471 | | - ) |
1472 | | - |
1473 | | - if not self.is_overridden('configure_optimizers', model): |
1474 | | - raise MisconfigurationException( |
1475 | | - 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' |
1476 | | - ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' |
1477 | | - ) |
1478 | | - |
1479 | | - # Check val_dataloader, validation_step and validation_epoch_end |
1480 | | - if self.is_overridden('val_dataloader', model): |
1481 | | - if not self.is_overridden('validation_step', model): |
1482 | | - raise MisconfigurationException( |
1483 | | - 'You have passed in a `val_dataloader()`' ' but have not defined `validation_step()`.' |
1484 | | - ) |
1485 | | - else: |
1486 | | - if not self.is_overridden('validation_epoch_end', model): |
1487 | | - rank_zero_warn( |
1488 | | - 'You have defined a `val_dataloader()` and have defined a `validation_step()`,' |
1489 | | - ' you may also want to define `validation_epoch_end()` for accumulating stats.', |
1490 | | - RuntimeWarning, |
1491 | | - ) |
1492 | | - else: |
1493 | | - if self.is_overridden('validation_step', model): |
1494 | | - raise MisconfigurationException( |
1495 | | - 'You have defined `validation_step()`,' ' but have not passed in a `val_dataloader()`.' |
1496 | | - ) |
1497 | | - |
1498 | | - # Check test_dataloader, test_step and test_epoch_end |
1499 | | - if self.is_overridden('test_dataloader', model): |
1500 | | - if not self.is_overridden('test_step', model): |
1501 | | - raise MisconfigurationException( |
1502 | | - 'You have passed in a `test_dataloader()`' ' but have not defined `test_step()`.' |
1503 | | - ) |
1504 | | - else: |
1505 | | - if not self.is_overridden('test_epoch_end', model): |
1506 | | - rank_zero_warn( |
1507 | | - 'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to' |
1508 | | - ' define `test_epoch_end()` for accumulating stats.', |
1509 | | - RuntimeWarning, |
1510 | | - ) |
1511 | | - else: |
1512 | | - if self.testing and self.is_overridden('test_step', model): |
1513 | | - raise MisconfigurationException( |
1514 | | - 'You have defined `test_step()` but did not' |
1515 | | - ' implement `test_dataloader` nor passed in `.test(test_dataloader)`.' |
1516 | | - ) |
1517 | | - |
1518 | 1454 | def barrier(self, name): |
1519 | 1455 | if self.use_ddp or self.use_ddp2: |
1520 | 1456 | pass |
|
0 commit comments