Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/nn_core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@ def _is_nnlogger(trainer: Trainer) -> bool:

def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if self._is_nnlogger(trainer):
trainer.logger: NNLogger
trainer.logger.upload_source()
trainer.logger.log_configuration(model=pl_module)
trainer.logger.watch_model(pl_module=pl_module)

def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if self._is_nnlogger(trainer):
trainer.logger: NNLogger
trainer.logger.upload_run_files()

def on_save_checkpoint(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
) -> None:
Expand Down
45 changes: 41 additions & 4 deletions src/nn_core/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import logging
import os
from typing import Optional
from typing import List, Optional

import dotenv
import numpy as np
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from pytorch_lightning import seed_everything
from rich.prompt import Prompt

logger = logging.getLogger(__name__)
pylogger = logging.getLogger(__name__)


def get_env(env_name: str, default: Optional[str] = None) -> str:
Expand All @@ -19,13 +24,17 @@ def get_env(env_name: str, default: Optional[str] = None) -> str:
"""
if env_name not in os.environ:
if default is None:
raise KeyError(f"{env_name} not defined and no default value is present!")
message = f"{env_name} not defined and no default value is present!"
pylogger.error(message)
raise KeyError(message)
return default

env_value: str = os.environ[env_name]
if not env_value:
if default is None:
raise ValueError(f"{env_name} has yet to be configured and no default value is present!")
message = f"{env_name} has yet to be configured and no default value is present!"
pylogger.error(message)
raise ValueError(message)
return default

return env_value
Expand All @@ -42,3 +51,31 @@ def load_envs(env_file: Optional[str] = None) -> None:
it searches for a `.env` file in the project.
"""
dotenv.load_dotenv(dotenv_path=env_file, override=True)


def enforce_tags(tags: Optional[List[str]]) -> List[str]:
if tags is None:
if "id" in HydraConfig().cfg.hydra.job:
# We are in multi-run setting (either via a sweep or a scheduler)
message: str = "You need to specify 'core.tags' in a multi-run setting!"
pylogger.error(message)
raise ValueError(message)

pylogger.warning("No tags provided, asking for tags...")
tags = Prompt.ask("Enter a list of comma separated tags", default="develop")
tags = [x.strip() for x in tags.split(",")]

pylogger.info(f"Tags: {tags if tags is not None else []}")
return tags


def seed_index_everything(train_cfg: DictConfig) -> None:
if "seed_index" in train_cfg and train_cfg.seed_index is not None:
seed_index = train_cfg.seed_index
seed_everything(42)
seeds = np.random.randint(np.iinfo(np.int32).max, size=max(42, seed_index + 1))
seed = seeds[seed_index]
seed_everything(seed)
pylogger.info(f"Setting seed {seed} from seeds[{seed_index}]")
else:
pylogger.warning("The seed has not been set! The reproducibility is not guaranteed.")
12 changes: 11 additions & 1 deletion src/nn_core/model_logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import logging
import os
import shutil
from pathlib import Path
from typing import Any, Dict, Optional, Union

Expand Down Expand Up @@ -55,7 +56,7 @@ def watch_model(self, pl_module: LightningModule):

def upload_source(self) -> None:
if self.logging_cfg.upload.source and self.wandb:
pylogger.info("Uploading source code to wandb")
pylogger.info("Uploading source code to W&B")
self.wrapped.experiment.log_code(
root=PROJECT_ROOT,
name=None,
Expand Down Expand Up @@ -201,3 +202,12 @@ def log_configuration(
# send hparams to all loggers
pylogger.debug("Logging 'cfg'")
self.wrapped.log_hyperparams(cfg)

def upload_run_files(self):
if self.logging_cfg.upload.run_files:
if self.wandb:
pylogger.info("Uploading run files to W&B")
shutil.copytree(self.run_dir, f"{self.wrapped.experiment.dir}/run_files")

# FIXME: symlink not working for some reason
# os.symlink(self.run_dir, f"{self.wrapped.experiment.dir}/run_files", target_is_directory=True)
54 changes: 53 additions & 1 deletion src/nn_core/resume.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
import logging
import re
from operator import xor
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple

import torch
import wandb
from omegaconf import DictConfig
from wandb.apis.public import Run

pylogger = logging.getLogger(__name__)

RUN_PATH_PATTERN = re.compile(r"^([^/]+)/([^/]+)/([^/]+)$")

RESUME_MODES = {
"continue": {
"restore_model": True,
"restore_run": True,
},
"hotstart": {
"restore_model": True,
"restore_run": False,
},
None: {
"restore_model": False,
"restore_run": False,
},
}


def resolve_ckpt(ckpt_or_run_path: str) -> str:
"""Resolve the run path or ckpt to a checkpoint.
Expand Down Expand Up @@ -61,3 +81,35 @@ def resolve_run_version(ckpt_or_run_path: Optional[str] = None, run_path: Option
if run_path is None:
run_path = resolve_run_path(ckpt_or_run_path)
return RUN_PATH_PATTERN.match(run_path).group(3)


def parse_restore(restore_cfg: DictConfig) -> Tuple[Optional[str], Optional[str]]:
ckpt_or_run_path = restore_cfg.ckpt_or_run_path
resume_mode = restore_cfg.mode

resume_ckpt_path = None
resume_run_version = None

if xor(bool(ckpt_or_run_path), bool(resume_mode)):
pylogger.warning(f"Inconsistent resume modality {resume_mode} and checkpoint path '{ckpt_or_run_path}'")

if resume_mode not in RESUME_MODES:
message = f"Unsupported resume mode {resume_mode}. Available resume modes are: {RESUME_MODES}"
pylogger.error(message)
raise ValueError(message)

flags = RESUME_MODES[resume_mode]
restore_model = flags["restore_model"]
restore_run = flags["restore_run"]

if ckpt_or_run_path is not None:
if restore_model:
resume_ckpt_path = resolve_ckpt(ckpt_or_run_path)
pylogger.info(f"Resume training from: '{resume_ckpt_path}'")

if restore_run:
run_path = resolve_run_path(ckpt_or_run_path)
resume_run_version = resolve_run_version(run_path=run_path)
pylogger.info(f"Resume logging to: '{run_path}'")

return resume_ckpt_path, resume_run_version
108 changes: 108 additions & 0 deletions src/nn_core/ui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import datetime
import operator
from pathlib import Path
from typing import List

import hydra
import omegaconf
import streamlit as st
import wandb
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose
from stqdm import stqdm

from nn_core.common import PROJECT_ROOT

WANDB_DIR: Path = PROJECT_ROOT / "wandb"
WANDB_DIR.mkdir(exist_ok=True, parents=True)

st_run_sel = st.sidebar


def local_checkpoint_selection(run_dir: Path, st_key: str) -> Path:
checkpoint_paths: List[Path] = list(run_dir.rglob("checkpoints/*"))
if len(checkpoint_paths) == 0:
st.error(f"There's no checkpoint under {run_dir}! Are you sure the restore was successful?")
st.stop()
checkpoint_path: Path = st_run_sel.selectbox(
label="Select a checkpoint",
index=0,
options=checkpoint_paths,
format_func=operator.attrgetter("name"),
key=f"checkpoint_select_{st_key}",
)

return checkpoint_path


def get_run_dir(entity: str, project: str, run_id: str) -> Path:
"""Get run directory.

:param run_path: "entity/project/run_id"
:return:
"""
api = wandb.Api()
run = api.run(path=f"{entity}/{project}/{run_id}")
created_at: datetime = datetime.datetime.strptime(run.created_at, "%Y-%m-%dT%H:%M:%S")
st.sidebar.markdown(body=f"[`Open on WandB`]({run.url})")

timestamp: str = created_at.strftime("%Y%m%d_%H%M%S")

matching_runs: List[Path] = [item for item in WANDB_DIR.iterdir() if item.is_dir() and item.name.endswith(run_id)]

if len(matching_runs) > 1:
st.error(f"More than one run matching unique id {run_id}! Are you sure about that?")
st.stop()

if len(matching_runs) == 1:
return matching_runs[0]

only_checkpoint: bool = st_run_sel.checkbox(label="Download only the checkpoint?", value=True)
if st_run_sel.button(label="Download"):
run_dir: Path = WANDB_DIR / f"restored-{timestamp}-{run.id}" / "files"
files = [file for file in run.files() if "checkpoint" in file.name or not only_checkpoint]
if len(files) == 0:
st.error(f"There is no file to download from this run! Check on WandB: {run.url}")
for file in stqdm(files, desc="Downloading files..."):
file.download(root=run_dir)
return run_dir
else:
st.stop()


def select_run_path(st_key: str, default_run_path: str):
run_path: str = st_run_sel.text_input(
label="Run path (entity/project/id):",
value=default_run_path,
key=f"run_path_select_{st_key}",
)
if not run_path:
st.stop()
tokens: List[str] = run_path.split("/")
if len(tokens) != 3:
st.error(f"This run path {run_path} doesn't look like a WandB run path! Are you sure about that?")
st.stop()

return tokens


def select_checkpoint(st_key: str = "MyAwesomeModel", default_run_path: str = ""):
entity, project, run_id = select_run_path(st_key=st_key, default_run_path=default_run_path)

run_dir: Path = get_run_dir(entity=entity, project=project, run_id=run_id)

return local_checkpoint_selection(run_dir, st_key=st_key)


def get_hydra_cfg(config_name: str = "default") -> omegaconf.DictConfig:
"""Instantiate and return the hydra config -- streamlit and jupyter compatible.

Args:
config_name: .yaml configuration name, without the extension

Returns:
The desired omegaconf.DictConfig
"""
GlobalHydra.instance().clear()
hydra.experimental.initialize_config_dir(config_dir=str(PROJECT_ROOT / "conf"))
return compose(config_name=config_name)