diff --git a/src/nn_core/callbacks.py b/src/nn_core/callbacks.py index a91d68b..6160470 100644 --- a/src/nn_core/callbacks.py +++ b/src/nn_core/callbacks.py @@ -16,10 +16,16 @@ def _is_nnlogger(trainer: Trainer) -> bool: def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if self._is_nnlogger(trainer): + trainer.logger: NNLogger trainer.logger.upload_source() trainer.logger.log_configuration(model=pl_module) trainer.logger.watch_model(pl_module=pl_module) + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if self._is_nnlogger(trainer): + trainer.logger: NNLogger + trainer.logger.upload_run_files() + def on_save_checkpoint( self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any] ) -> None: diff --git a/src/nn_core/model_logging.py b/src/nn_core/model_logging.py index ba843f0..103964b 100644 --- a/src/nn_core/model_logging.py +++ b/src/nn_core/model_logging.py @@ -1,6 +1,7 @@ import argparse import logging import os +import shutil from pathlib import Path from typing import Any, Dict, Optional, Union @@ -55,7 +56,7 @@ def watch_model(self, pl_module: LightningModule): def upload_source(self) -> None: if self.logging_cfg.upload.source and self.wandb: - pylogger.info("Uploading source code to wandb") + pylogger.info("Uploading source code to W&B") self.wrapped.experiment.log_code( root=PROJECT_ROOT, name=None, @@ -201,3 +202,12 @@ def log_configuration( # send hparams to all loggers pylogger.debug("Logging 'cfg'") self.wrapped.log_hyperparams(cfg) + + def upload_run_files(self): + if self.logging_cfg.upload.run_files: + if self.wandb: + pylogger.info("Uploading run files to W&B") + shutil.copytree(self.run_dir, f"{self.wrapped.experiment.dir}/run_files") + + # FIXME: symlink not working for some reason + # os.symlink(self.run_dir, f"{self.wrapped.experiment.dir}/run_files", target_is_directory=True)