From d7a15a743646a24f1620001cf97d06cb743db536 Mon Sep 17 00:00:00 2001 From: Mayank Mittal Date: Mon, 11 Dec 2023 11:34:39 +0100 Subject: [PATCH 1/3] adds logging of git files for wandb and neptune --- rsl_rl/runners/on_policy_runner.py | 7 ++++++- rsl_rl/utils/neptune_utils.py | 4 ++++ rsl_rl/utils/utils.py | 10 ++++++++-- rsl_rl/utils/wandb_utils.py | 3 +++ 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index fbd992de..9e0a459c 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -156,7 +156,12 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals self.save(os.path.join(self.log_dir, f"model_{it}.pt")) ep_infos.clear() if it == start_iter: - store_code_state(self.log_dir, self.git_status_repos) + # obtain all the diff files + git_file_paths = store_code_state(self.log_dir, self.git_status_repos) + # if possible store them to wandb + if self.logger_type in ["wandb", "neptune"] and git_file_paths: + for path in git_file_paths: + self.writer.save_file(path) self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt")) diff --git a/rsl_rl/utils/neptune_utils.py b/rsl_rl/utils/neptune_utils.py index 78589889..2b3402df 100644 --- a/rsl_rl/utils/neptune_utils.py +++ b/rsl_rl/utils/neptune_utils.py @@ -86,3 +86,7 @@ def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): def save_model(self, model_path, iter): self.neptune_logger.run["model/saved_model_" + str(iter)].upload(model_path) + + def save_file(self, path, iter=None): + name = path.rsplit("/", 1)[-1].split(".")[0] + self.neptune_logger.run["git_diff/" + name].upload(path) diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py index 3c1dcae4..4a9abaad 100644 --- a/rsl_rl/utils/utils.py +++ b/rsl_rl/utils/utils.py @@ -58,15 +58,21 @@ def unpad_trajectories(trajectories, masks): ) -def store_code_state(logdir, repositories): +def store_code_state(logdir, repositories) -> list: + file_paths = [] for repository_file_path in repositories: try: repo = git.Repo(repository_file_path, search_parent_directories=True) except git.InvalidGitRepositoryError: # skip if not a git repository continue + # get the name of the repository repo_name = pathlib.Path(repo.working_dir).name t = repo.head.commit.tree + diff_file_name = os.path.join(logdir, f"{repo_name}_git.diff") content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}" - with open(os.path.join(logdir, f"{repo_name}_git.diff"), "x") as f: + with open(diff_file_name, "x") as f: f.write(content) + # add the file path to the list of files to be uploaded + file_paths.append(diff_file_name) + return file_paths diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py index 86b61cc7..630f8706 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_utils.py @@ -75,3 +75,6 @@ def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): def save_model(self, model_path, iter): wandb.save(model_path) + + def save_file(self, path, iter=None): + wandb.save(path) From cde1e87a19cf387d92f769f83a75d5b280c3d8f2 Mon Sep 17 00:00:00 2001 From: Mayank Mittal Date: Mon, 11 Dec 2023 11:50:05 +0100 Subject: [PATCH 2/3] fixes diff storage location --- rsl_rl/utils/utils.py | 4 +++- rsl_rl/utils/wandb_utils.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py index 4a9abaad..88dd7411 100644 --- a/rsl_rl/utils/utils.py +++ b/rsl_rl/utils/utils.py @@ -59,6 +59,8 @@ def unpad_trajectories(trajectories, masks): def store_code_state(logdir, repositories) -> list: + git_log_dir = os.path.join(logdir, "git") + os.makedirs(git_log_dir, exist_ok=True) file_paths = [] for repository_file_path in repositories: try: @@ -69,7 +71,7 @@ def store_code_state(logdir, repositories) -> list: # get the name of the repository repo_name = pathlib.Path(repo.working_dir).name t = repo.head.commit.tree - diff_file_name = os.path.join(logdir, f"{repo_name}_git.diff") + diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff") content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}" with open(diff_file_name, "x") as f: f.write(content) diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py index 630f8706..2868ce91 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_utils.py @@ -74,7 +74,7 @@ def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg) def save_model(self, model_path, iter): - wandb.save(model_path) + wandb.save(model_path, base_path=os.path.dirname(model_path)) def save_file(self, path, iter=None): - wandb.save(path) + wandb.save(path, base_path=os.path.dirname(path)) From 0dc9544952e69c4b90e8a97fd25c24e1d53d27f3 Mon Sep 17 00:00:00 2001 From: Mayank Mittal Date: Mon, 11 Dec 2023 12:02:31 +0100 Subject: [PATCH 3/3] updates to new neptune --- rsl_rl/utils/neptune_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rsl_rl/utils/neptune_utils.py b/rsl_rl/utils/neptune_utils.py index 2b3402df..f06cc625 100644 --- a/rsl_rl/utils/neptune_utils.py +++ b/rsl_rl/utils/neptune_utils.py @@ -8,14 +8,14 @@ from torch.utils.tensorboard import SummaryWriter try: - import neptune.new as neptune + import neptune except ModuleNotFoundError: raise ModuleNotFoundError("neptune-client is required to log to Neptune.") class NeptuneLogger: def __init__(self, project, token): - self.run = neptune.init(project=project, api_token=token) + self.run = neptune.init_run(project=project, api_token=token) def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): self.run["runner_cfg"] = runner_cfg