From 1b7be4d5dc47b50b71be98e93adea318b14d0a7e Mon Sep 17 00:00:00 2001
From: Luca Moschella
Date: Sat, 8 Jan 2022 23:55:31 +0100
Subject: [PATCH] Improve compatibility with the template

---
 env.yaml                     |  2 +-
 src/nn_core/callbacks.py     |  4 +++-
 src/nn_core/model_logging.py | 15 ++++++++-------
 3 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/env.yaml b/env.yaml
index e76c9c0..4f6423e 100644
--- a/env.yaml
+++ b/env.yaml
@@ -1,4 +1,4 @@
-name: nn-core
+name: nn-template-core
 channels:
   - defaults
   - pytorch
diff --git a/src/nn_core/callbacks.py b/src/nn_core/callbacks.py
index 07c232d..f1f180a 100644
--- a/src/nn_core/callbacks.py
+++ b/src/nn_core/callbacks.py
@@ -9,6 +9,7 @@
 from pytorch_lightning.loggers import LightningLoggerBase
 
 from nn_core.common import PROJECT_ROOT
+from nn_core.model_logging import NNLogger
 
 pylogger = logging.getLogger(__name__)
 
@@ -28,7 +29,8 @@ def __init__(self, upload: Optional[Dict[str, bool]], logger: Optional[DictConfi
         self.wandb: bool = self.logger_cfg["_target_"].endswith("WandbLogger")
 
     def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        trainer.logger.log_configuration(model=pl_module)
+        if isinstance(trainer.logger, NNLogger):
+            trainer.logger.log_configuration(model=pl_module)
 
         if "wandb_watch" in self.kwargs:
             trainer.logger.wrapped.watch(pl_module, **self.kwargs["wandb_watch"])
diff --git a/src/nn_core/model_logging.py b/src/nn_core/model_logging.py
index e936c7d..b699caf 100644
--- a/src/nn_core/model_logging.py
+++ b/src/nn_core/model_logging.py
@@ -21,7 +21,8 @@ def __init__(self, logger: Optional[LightningLoggerBase], storage_dir: str, cfg)
         self.cfg = cfg
 
     def __getattr__(self, item):
-        return getattr(self.wrapped, item)
+        if self.wrapped is not None:
+            return getattr(self.wrapped, item)
 
     @property
     def save_dir(self) -> Optional[str]:
@@ -110,6 +111,12 @@ def log_configuration(
         if isinstance(cfg, DictConfig):
             cfg: Union[Dict[str, Any], argparse.Namespace, DictConfig] = OmegaConf.to_container(cfg, resolve=True)
 
+        # Store the YaML config separately into the wandb dir
+        yaml_conf: str = OmegaConf.to_yaml(cfg=cfg)
+        run_dir: Path = Path(self.run_dir)
+        run_dir.mkdir(exist_ok=True, parents=True)
+        (run_dir / "config.yaml").write_text(yaml_conf)
+
         # save number of model parameters
         cfg[f"{_STATS_KEY}/params_total"] = sum(p.numel() for p in model.parameters())
         cfg[f"{_STATS_KEY}/params_trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -117,9 +124,3 @@
 
         # send hparams to all loggers
         self.wrapped.log_hyperparams(cfg)
-
-        # Store the YaML config separately into the wandb dir
-        yaml_conf: str = OmegaConf.to_yaml(cfg=cfg)
-        run_dir: Path = Path(self.run_dir)
-        run_dir.mkdir(exist_ok=True, parents=True)
-        (run_dir / "config.yaml").write_text(yaml_conf)
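
Note (outside the patch): a minimal, self-contained sketch of the guarded-delegation pattern that the __getattr__ change above introduces, using a hypothetical Wrapper class rather than the template's NNLogger. With nothing wrapped, delegated attribute access yields None instead of raising AttributeError on NoneType; the separate isinstance(trainer.logger, NNLogger) check in the callback covers the case where Lightning was configured with a plain logger that has no log_configuration method at all.

    # Illustration only; mirrors the guard added to NNLogger.__getattr__ above.
    from typing import Any, Optional


    class Wrapper:
        def __init__(self, wrapped: Optional[Any]):
            self.wrapped = wrapped

        def __getattr__(self, item: str):
            # Without the None check, getattr(None, item) raises AttributeError
            # whenever no backend object is wrapped (e.g. logging disabled).
            if self.wrapped is not None:
                return getattr(self.wrapped, item)
            # Falling through returns None for any delegated attribute.


    print(Wrapper(wrapped=None).anything)  # None, no crash
    print(Wrapper(wrapped="abc").upper())  # "ABC", delegated to the wrapped object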