From 79f9fb12debd7ede2667cda7b93b3481d71b0c83 Mon Sep 17 00:00:00 2001 From: Luca Moschella Date: Fri, 21 Jan 2022 17:22:42 +0100 Subject: [PATCH 1/3] Move functions from template to core (#9) * Move parse_restore to nn-core * Move enforce_tags to nn-core * Move seed_index logic to nn-core --- src/nn_core/common/utils.py | 45 ++++++++++++++++++++++++++++--- src/nn_core/resume.py | 54 ++++++++++++++++++++++++++++++++++++- 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/src/nn_core/common/utils.py b/src/nn_core/common/utils.py index cb3a9b2..6538e7f 100644 --- a/src/nn_core/common/utils.py +++ b/src/nn_core/common/utils.py @@ -1,10 +1,15 @@ import logging import os -from typing import Optional +from typing import List, Optional import dotenv +import numpy as np +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig +from pytorch_lightning import seed_everything +from rich.prompt import Prompt -logger = logging.getLogger(__name__) +pylogger = logging.getLogger(__name__) def get_env(env_name: str, default: Optional[str] = None) -> str: @@ -19,13 +24,17 @@ def get_env(env_name: str, default: Optional[str] = None) -> str: """ if env_name not in os.environ: if default is None: - raise KeyError(f"{env_name} not defined and no default value is present!") + message = f"{env_name} not defined and no default value is present!" + pylogger.error(message) + raise KeyError(message) return default env_value: str = os.environ[env_name] if not env_value: if default is None: - raise ValueError(f"{env_name} has yet to be configured and no default value is present!") + message = f"{env_name} has yet to be configured and no default value is present!" + pylogger.error(message) + raise ValueError(message) return default return env_value @@ -42,3 +51,31 @@ def load_envs(env_file: Optional[str] = None) -> None: it searches for a `.env` file in the project. """ dotenv.load_dotenv(dotenv_path=env_file, override=True) + + +def enforce_tags(tags: Optional[List[str]]) -> List[str]: + if tags is None: + if "id" in HydraConfig().cfg.hydra.job: + # We are in multi-run setting (either via a sweep or a scheduler) + message: str = "You need to specify 'core.tags' in a multi-run setting!" + pylogger.error(message) + raise ValueError(message) + + pylogger.warning("No tags provided, asking for tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="develop") + tags = [x.strip() for x in tags.split(",")] + + pylogger.info(f"Tags: {tags if tags is not None else []}") + return tags + + +def seed_index_everything(train_cfg: DictConfig) -> None: + if "seed_index" in train_cfg and train_cfg.seed_index is not None: + seed_index = train_cfg.seed_index + seed_everything(42) + seeds = np.random.randint(np.iinfo(np.int32).max, size=max(42, seed_index + 1)) + seed = seeds[seed_index] + seed_everything(seed) + pylogger.info(f"Setting seed {seed} from seeds[{seed_index}]") + else: + pylogger.warning("The seed has not been set! The reproducibility is not guaranteed.") diff --git a/src/nn_core/resume.py b/src/nn_core/resume.py index 1ca79d9..7abb9f2 100644 --- a/src/nn_core/resume.py +++ b/src/nn_core/resume.py @@ -1,13 +1,33 @@ +import logging import re +from operator import xor from pathlib import Path -from typing import Optional +from typing import Optional, Tuple import torch import wandb +from omegaconf import DictConfig from wandb.apis.public import Run +pylogger = logging.getLogger(__name__) + RUN_PATH_PATTERN = re.compile(r"^([^/]+)/([^/]+)/([^/]+)$") +RESUME_MODES = { + "continue": { + "restore_model": True, + "restore_run": True, + }, + "hotstart": { + "restore_model": True, + "restore_run": False, + }, + None: { + "restore_model": False, + "restore_run": False, + }, +} + def resolve_ckpt(ckpt_or_run_path: str) -> str: """Resolve the run path or ckpt to a checkpoint. @@ -61,3 +81,35 @@ def resolve_run_version(ckpt_or_run_path: Optional[str] = None, run_path: Option if run_path is None: run_path = resolve_run_path(ckpt_or_run_path) return RUN_PATH_PATTERN.match(run_path).group(3) + + +def parse_restore(restore_cfg: DictConfig) -> Tuple[Optional[str], Optional[str]]: + ckpt_or_run_path = restore_cfg.ckpt_or_run_path + resume_mode = restore_cfg.mode + + resume_ckpt_path = None + resume_run_version = None + + if xor(bool(ckpt_or_run_path), bool(resume_mode)): + pylogger.warning(f"Inconsistent resume modality {resume_mode} and checkpoint path '{ckpt_or_run_path}'") + + if resume_mode not in RESUME_MODES: + message = f"Unsupported resume mode {resume_mode}. Available resume modes are: {RESUME_MODES}" + pylogger.error(message) + raise ValueError(message) + + flags = RESUME_MODES[resume_mode] + restore_model = flags["restore_model"] + restore_run = flags["restore_run"] + + if ckpt_or_run_path is not None: + if restore_model: + resume_ckpt_path = resolve_ckpt(ckpt_or_run_path) + pylogger.info(f"Resume training from: '{resume_ckpt_path}'") + + if restore_run: + run_path = resolve_run_path(ckpt_or_run_path) + resume_run_version = resolve_run_version(run_path=run_path) + pylogger.info(f"Resume logging to: '{run_path}'") + + return resume_ckpt_path, resume_run_version From 90a1737a91beda7c8e0807300e50fd90ec9ff9a6 Mon Sep 17 00:00:00 2001 From: Luca Moschella Date: Fri, 21 Jan 2022 17:28:37 +0100 Subject: [PATCH 2/3] Add upload_run_files functionality (#10) --- src/nn_core/callbacks.py | 6 ++++++ src/nn_core/model_logging.py | 12 +++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/nn_core/callbacks.py b/src/nn_core/callbacks.py index a91d68b..6160470 100644 --- a/src/nn_core/callbacks.py +++ b/src/nn_core/callbacks.py @@ -16,10 +16,16 @@ def _is_nnlogger(trainer: Trainer) -> bool: def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if self._is_nnlogger(trainer): + trainer.logger: NNLogger trainer.logger.upload_source() trainer.logger.log_configuration(model=pl_module) trainer.logger.watch_model(pl_module=pl_module) + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if self._is_nnlogger(trainer): + trainer.logger: NNLogger + trainer.logger.upload_run_files() + def on_save_checkpoint( self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any] ) -> None: diff --git a/src/nn_core/model_logging.py b/src/nn_core/model_logging.py index ba843f0..103964b 100644 --- a/src/nn_core/model_logging.py +++ b/src/nn_core/model_logging.py @@ -1,6 +1,7 @@ import argparse import logging import os +import shutil from pathlib import Path from typing import Any, Dict, Optional, Union @@ -55,7 +56,7 @@ def watch_model(self, pl_module: LightningModule): def upload_source(self) -> None: if self.logging_cfg.upload.source and self.wandb: - pylogger.info("Uploading source code to wandb") + pylogger.info("Uploading source code to W&B") self.wrapped.experiment.log_code( root=PROJECT_ROOT, name=None, @@ -201,3 +202,12 @@ def log_configuration( # send hparams to all loggers pylogger.debug("Logging 'cfg'") self.wrapped.log_hyperparams(cfg) + + def upload_run_files(self): + if self.logging_cfg.upload.run_files: + if self.wandb: + pylogger.info("Uploading run files to W&B") + shutil.copytree(self.run_dir, f"{self.wrapped.experiment.dir}/run_files") + + # FIXME: symlink not working for some reason + # os.symlink(self.run_dir, f"{self.wrapped.experiment.dir}/run_files", target_is_directory=True) From ba8be70634dc327a3a3b706af5e7cf7fad2f9a47 Mon Sep 17 00:00:00 2001 From: Luca Moschella Date: Fri, 21 Jan 2022 17:43:14 +0100 Subject: [PATCH 3/3] Move ui_utils entirely to nn-core (#11) * Move ui_utils entirely to nn-core Co-authored-by: Valentino Maiorca --- src/nn_core/ui.py | 108 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 src/nn_core/ui.py diff --git a/src/nn_core/ui.py b/src/nn_core/ui.py new file mode 100644 index 0000000..2fc7425 --- /dev/null +++ b/src/nn_core/ui.py @@ -0,0 +1,108 @@ +import datetime +import operator +from pathlib import Path +from typing import List + +import hydra +import omegaconf +import streamlit as st +import wandb +from hydra.core.global_hydra import GlobalHydra +from hydra.experimental import compose +from stqdm import stqdm + +from nn_core.common import PROJECT_ROOT + +WANDB_DIR: Path = PROJECT_ROOT / "wandb" +WANDB_DIR.mkdir(exist_ok=True, parents=True) + +st_run_sel = st.sidebar + + +def local_checkpoint_selection(run_dir: Path, st_key: str) -> Path: + checkpoint_paths: List[Path] = list(run_dir.rglob("checkpoints/*")) + if len(checkpoint_paths) == 0: + st.error(f"There's no checkpoint under {run_dir}! Are you sure the restore was successful?") + st.stop() + checkpoint_path: Path = st_run_sel.selectbox( + label="Select a checkpoint", + index=0, + options=checkpoint_paths, + format_func=operator.attrgetter("name"), + key=f"checkpoint_select_{st_key}", + ) + + return checkpoint_path + + +def get_run_dir(entity: str, project: str, run_id: str) -> Path: + """Get run directory. + + :param run_path: "entity/project/run_id" + :return: + """ + api = wandb.Api() + run = api.run(path=f"{entity}/{project}/{run_id}") + created_at: datetime = datetime.datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S") + st.sidebar.markdown(body=f"[`Open on WandB`]({run.url})") + + timestamp: str = created_at.strftime("%Y%m%d_%H%M%S") + + matching_runs: List[Path] = [item for item in WANDB_DIR.iterdir() if item.is_dir() and item.name.endswith(run_id)] + + if len(matching_runs) > 1: + st.error(f"More than one run matching unique id {run_id}! Are you sure about that?") + st.stop() + + if len(matching_runs) == 1: + return matching_runs[0] + + only_checkpoint: bool = st_run_sel.checkbox(label="Download only the checkpoint?", value=True) + if st_run_sel.button(label="Download"): + run_dir: Path = WANDB_DIR / f"restored-{timestamp}-{run.id}" / "files" + files = [file for file in run.files() if "checkpoint" in file.name or not only_checkpoint] + if len(files) == 0: + st.error(f"There is no file to download from this run! Check on WandB: {run.url}") + for file in stqdm(files, desc="Downloading files..."): + file.download(root=run_dir) + return run_dir + else: + st.stop() + + +def select_run_path(st_key: str, default_run_path: str): + run_path: str = st_run_sel.text_input( + label="Run path (entity/project/id):", + value=default_run_path, + key=f"run_path_select_{st_key}", + ) + if not run_path: + st.stop() + tokens: List[str] = run_path.split("/") + if len(tokens) != 3: + st.error(f"This run path {run_path} doesn't look like a WandB run path! Are you sure about that?") + st.stop() + + return tokens + + +def select_checkpoint(st_key: str = "MyAwesomeModel", default_run_path: str = ""): + entity, project, run_id = select_run_path(st_key=st_key, default_run_path=default_run_path) + + run_dir: Path = get_run_dir(entity=entity, project=project, run_id=run_id) + + return local_checkpoint_selection(run_dir, st_key=st_key) + + +def get_hydra_cfg(config_name: str = "default") -> omegaconf.DictConfig: + """Instantiate and return the hydra config -- streamlit and jupyter compatible. + + Args: + config_name: .yaml configuration name, without the extension + + Returns: + The desired omegaconf.DictConfig + """ + GlobalHydra.instance().clear() + hydra.experimental.initialize_config_dir(config_dir=str(PROJECT_ROOT / "conf")) + return compose(config_name=config_name)