2 changes: 1 addition & 1 deletion env.yaml
@@ -1,4 +1,4 @@
-name: nn-core
+name: nn-template-core
 channels:
   - defaults
   - pytorch
4 changes: 3 additions & 1 deletion src/nn_core/callbacks.py
@@ -9,6 +9,7 @@
 from pytorch_lightning.loggers import LightningLoggerBase

 from nn_core.common import PROJECT_ROOT
+from nn_core.model_logging import NNLogger

 pylogger = logging.getLogger(__name__)

@@ -28,7 +29,8 @@ def __init__(self, upload: Optional[Dict[str, bool]], logger: Optional[DictConfi
         self.wandb: bool = self.logger_cfg["_target_"].endswith("WandbLogger")

     def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        trainer.logger.log_configuration(model=pl_module)
+        if isinstance(trainer.logger, NNLogger):
+            trainer.logger.log_configuration(model=pl_module)

         if "wandb_watch" in self.kwargs:
             trainer.logger.wrapped.watch(pl_module, **self.kwargs["wandb_watch"])
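Note on the guard above: `trainer.logger` is not guaranteed to be an `NNLogger`; it can be `None` (logging disabled) or a stock Lightning logger that has no `log_configuration` method, in which case the old unconditional call would raise `AttributeError`. A minimal standalone sketch of the pattern, with stand-in classes (not the repository's code):

```python
# Sketch only: stand-ins for pytorch_lightning loggers and nn_core's NNLogger.
class PlainLogger:
    """Stands in for a stock Lightning logger (e.g. TensorBoardLogger)."""


class NNLogger:
    """Stands in for nn_core.model_logging.NNLogger."""

    def log_configuration(self, model=None):
        print("configuration logged")


def on_train_start(logger, model=None):
    # Guarded call: only NNLogger knows how to persist the run configuration.
    if isinstance(logger, NNLogger):
        logger.log_configuration(model=model)


on_train_start(NNLogger())    # -> "configuration logged"
on_train_start(PlainLogger())  # skipped instead of raising AttributeError
on_train_start(None)           # also skipped
```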
15 changes: 8 additions & 7 deletions src/nn_core/model_logging.py
@@ -21,7 +21,8 @@ def __init__(self, logger: Optional[LightningLoggerBase], storage_dir: str, cfg)
         self.cfg = cfg

     def __getattr__(self, item):
-        return getattr(self.wrapped, item)
+        if self.wrapped is not None:
+            return getattr(self.wrapped, item)

     @property
     def save_dir(self) -> Optional[str]:
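For context, `__getattr__` is only invoked when normal attribute lookup fails, and `NNLogger` uses it to forward unknown attributes to the wrapped logger. With the guard, a proxy whose `wrapped` is `None` now yields `None` for forwarded attributes instead of raising from `getattr(None, item)`. A minimal sketch of that delegation pattern, using hypothetical names (not the repository's code):

```python
# Sketch only: attribute delegation to an optional wrapped logger.
class WrappedLogger:
    experiment = "wandb-run"  # hypothetical attribute to forward


class LoggerProxy:
    def __init__(self, wrapped=None):
        self.wrapped = wrapped  # may be None when logging is disabled

    def __getattr__(self, item):
        # Reached only when normal lookup fails; guard against wrapped=None,
        # where getattr(None, item) would raise AttributeError.
        if self.wrapped is not None:
            return getattr(self.wrapped, item)


print(LoggerProxy(WrappedLogger()).experiment)  # -> "wandb-run"
print(LoggerProxy(None).experiment)             # -> None, no AttributeError
```

One trade-off of this style: missing attributes on an unwrapped proxy silently become `None` rather than failing loudly, which is the behaviour the patch opts into.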
@@ -110,16 +111,16 @@ def log_configuration(
         if isinstance(cfg, DictConfig):
             cfg: Union[Dict[str, Any], argparse.Namespace, DictConfig] = OmegaConf.to_container(cfg, resolve=True)

-        # Store the YaML config separately into the wandb dir
-        yaml_conf: str = OmegaConf.to_yaml(cfg=cfg)
-        run_dir: Path = Path(self.run_dir)
-        run_dir.mkdir(exist_ok=True, parents=True)
-        (run_dir / "config.yaml").write_text(yaml_conf)
-
         # save number of model parameters
         cfg[f"{_STATS_KEY}/params_total"] = sum(p.numel() for p in model.parameters())
         cfg[f"{_STATS_KEY}/params_trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
         cfg[f"{_STATS_KEY}/params_not_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)

         # send hparams to all loggers
         self.wrapped.log_hyperparams(cfg)
+
+        # Store the YaML config separately into the wandb dir
+        yaml_conf: str = OmegaConf.to_yaml(cfg=cfg)
+        run_dir: Path = Path(self.run_dir)
+        run_dir.mkdir(exist_ok=True, parents=True)
+        (run_dir / "config.yaml").write_text(yaml_conf)