feat: add git_utils and model_utils.
gao-hongnan committed Jun 17, 2024
1 parent 4cd578d commit de2690f
Showing 2 changed files with 232 additions and 0 deletions.
154 changes: 154 additions & 0 deletions omnivault/utils/reproducibility/git_utils.py
@@ -0,0 +1,154 @@
"""Core functions for Git versioning."""
from __future__ import annotations

import logging
import subprocess
from typing import List

logger = logging.getLogger(__name__)

__all__ = ["check_git_status", "get_git_commit_hash", "commit_changes", "push_changes"]

def check_git_status(working_dir: str | None = None) -> bool:
    """
    Check the Git status of the working directory. If there are untracked or
    uncommitted changes, return False.

    Parameters
    ----------
    working_dir : str, optional
        The path of the working directory where the Git command should be executed,
        by default None. If None, it uses the current working directory.

    Returns
    -------
    bool
        True if there are no untracked or uncommitted changes in the working directory,
        False otherwise.
    """
    # `git status --porcelain` prints one line per changed or untracked file,
    # so an empty output means the working tree is clean.
    status_output = subprocess.check_output(["git", "status", "--porcelain"], cwd=working_dir).decode("utf-8").strip()
    return len(status_output) == 0
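
A minimal usage sketch (illustrative only, not part of the commit; the repository path is hypothetical):

    from omnivault.utils.reproducibility.git_utils import check_git_status

    # True only when `git status --porcelain` prints nothing.
    if check_git_status(working_dir="/path/to/repo"):
        print("Working tree is clean.")
    else:
        print("Untracked or uncommitted changes present.")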


def log_message_if_working_dir_is_none(working_dir: str | None = None) -> None:
    """
    Log an informative message if no working directory is provided.

    Parameters
    ----------
    working_dir : str, optional
        The path of the working directory, by default None.
        If None, an info message is logged.
    """
    if working_dir is None:
        logger.info("Working directory not provided. Defaulting to current directory.")
    logger.info("Working directory: %s", working_dir)


def get_git_commit_hash(working_dir: str | None = None, are_there_untracked_or_uncommitted: bool = False) -> str:
    """
    Get the current Git commit hash.

    If Git is not installed or the working directory is not a Git repository,
    the function returns "N/A".

    Parameters
    ----------
    working_dir : str, optional
        The path of the working directory where the Git command should be executed,
        by default None. If None, it uses the current working directory ".".
    are_there_untracked_or_uncommitted : bool, optional
        If True, raise an error when there are untracked or uncommitted changes
        in the working directory, by default False.

    Returns
    -------
    commit_hash : str
        The Git commit hash, or "N/A" if Git is not installed or the working
        directory is not a Git repository.
    """
    log_message_if_working_dir_is_none(working_dir)

    git_command = ["git", "rev-parse", "HEAD"]

    try:
        if are_there_untracked_or_uncommitted and not check_git_status(working_dir):
            error_message = (
                "There are untracked or uncommitted files in the working directory. "
                "Please commit or stash them before running training as the commit hash "
                "will be used to tag the model."
            )
            raise RuntimeError(error_message)

        commit_hash = subprocess.check_output(git_command, cwd=working_dir).decode("utf-8").strip()
    except FileNotFoundError:
        logger.exception("Git not found or the provided working directory doesn't exist.")
        commit_hash = "N/A"
    except subprocess.CalledProcessError:
        logger.exception("The provided directory is not a Git repository.")
        commit_hash = "N/A"

    return commit_hash
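
How the guard is intended to be used at the start of a training run — a sketch assuming the current directory is a Git repository:

    from omnivault.utils.reproducibility.git_utils import get_git_commit_hash

    # With the flag set, a dirty working tree raises RuntimeError instead of
    # tagging the run with a hash that does not reflect the actual code state.
    commit_hash = get_git_commit_hash(working_dir=".", are_there_untracked_or_uncommitted=True)
    print(commit_hash)  # a 40-character SHA, or "N/A" outside a repository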


def commit_changes(
    commit_message: str,
    file_paths: List[str] | None = None,
    working_dir: str | None = None,
) -> None:
    """
    Commit changes to a Git repository.

    Parameters
    ----------
    commit_message : str
        The message to use for the commit.
    file_paths : List[str] | None, default=None
        List of file paths to be added to the commit. If not specified,
        all changes will be committed.
    working_dir : str | None, default=None
        The path of the working directory where the Git commands should be executed.
        If None, the commands will be executed in the current working directory.
    """
    log_message_if_working_dir_is_none(working_dir)

    try:
        # Add files to the staging area
        if file_paths is None:
            subprocess.run(["git", "add", "."], check=True, cwd=working_dir)
        else:
            for file_path in file_paths:
                subprocess.run(["git", "add", file_path], check=True, cwd=working_dir)

        # Commit the changes
        subprocess.run(["git", "commit", "-m", commit_message], check=True, cwd=working_dir)
    except subprocess.CalledProcessError:
        logger.exception("An error occurred while committing changes.")


def push_changes(
    remote_name: str = "origin",
    branch_name: str = "master",
    working_dir: str | None = None,
) -> None:
    """
    Push committed changes to a remote repository.

    Parameters
    ----------
    remote_name : str, default='origin'
        The name of the remote repository to push to.
    branch_name : str, default='master'
        The name of the branch to push.
    working_dir : str, optional
        The path of the working directory where the Git command should be executed,
        by default None. If None, it uses the current working directory.
    """
    log_message_if_working_dir_is_none(working_dir)

    try:
        # Push the changes
        subprocess.run(["git", "push", remote_name, branch_name], check=True, cwd=working_dir)
    except subprocess.CalledProcessError:
        logger.exception("An error occurred while pushing changes.")
78 changes: 78 additions & 0 deletions omnivault/utils/torch_utils/model_utils.py
@@ -0,0 +1,78 @@
from typing import Any, Tuple

import torch
from torch import nn

__all__ = ["total_trainable_parameters", "total_parameters", "compare_models"]


def total_trainable_parameters(module: nn.Module) -> int:
    """Returns the number of trainable parameters in the model."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def total_parameters(module: nn.Module) -> int:
    """Returns the total number of parameters in the model, including non-trainable."""
    return sum(p.numel() for p in module.parameters())
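
A quick sanity check of the two counters on a toy layer (illustrative only):

    from torch import nn

    from omnivault.utils.torch_utils.model_utils import total_parameters, total_trainable_parameters

    layer = nn.Linear(10, 2)          # 10 * 2 weights + 2 biases = 22 parameters
    layer.bias.requires_grad_(False)  # freeze the bias

    print(total_parameters(layer))            # 22
    print(total_trainable_parameters(layer))  # 20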


def compare_models(model_a: nn.Module, model_b: nn.Module) -> bool:
    """
    Compare two PyTorch models to check if they have identical parameters.

    Parameters
    ----------
    model_a : nn.Module
        The first model to compare.
    model_b : nn.Module
        The second model to compare.

    Returns
    -------
    bool
        Returns True if both models have identical parameters, False otherwise.
    """
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()
    # Zipping the two state dicts would silently truncate when the models have
    # a different number of parameters, so compare the key sets explicitly first.
    if state_a.keys() != state_b.keys():
        return False
    return all(torch.equal(state_a[name], state_b[name]) for name in state_a)
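
A sketch of the expected behavior on a copied model (illustrative only):

    from copy import deepcopy

    import torch
    from torch import nn

    from omnivault.utils.torch_utils.model_utils import compare_models

    model_a = nn.Linear(4, 4)
    model_b = deepcopy(model_a)
    assert compare_models(model_a, model_b)      # identical copies

    with torch.no_grad():
        model_b.weight.add_(1.0)
    assert not compare_models(model_a, model_b)  # weights now differ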


def compare_models_and_report_differences(model_a: nn.Module, model_b: nn.Module) -> Tuple[bool, Any]:
    """
    Compare two PyTorch models to check if they have identical parameters.

    Parameters
    ----------
    model_a : nn.Module
        The first model to compare.
    model_b : nn.Module
        The second model to compare.

    Returns
    -------
    Tuple[bool, Any]
        Returns a tuple with the first element as a boolean indicating if the models are identical.
        The second element is a dictionary containing the differences between the models if they are not identical.
    """
    model_a_dict = model_a.state_dict()
    model_b_dict = model_b.state_dict()

    if set(model_a_dict.keys()) != set(model_b_dict.keys()):
        # Early exit if model architectures are different (different sets of parameter keys)
        return False, {"error": "Models have different architectures and cannot be compared."}

    differences = {}
    for name in model_a_dict:
        param_a = model_a_dict[name]
        param_b = model_b_dict[name]
        if not torch.equal(param_a, param_b):
            differences[name] = {
                "model_a": param_a.detach().cpu().numpy(),
                "model_b": param_b.detach().cpu().numpy(),
            }

    if differences:
        return False, differences
    return True, None
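
A sketch of reading the difference report (illustrative only):

    from copy import deepcopy

    import torch
    from torch import nn

    from omnivault.utils.torch_utils.model_utils import compare_models_and_report_differences

    model_a = nn.Linear(4, 4)
    model_b = deepcopy(model_a)
    with torch.no_grad():
        model_b.bias.zero_()

    identical, differences = compare_models_and_report_differences(model_a, model_b)
    print(identical)          # False unless the bias happened to be all zeros already
    print(list(differences))  # ["bias"] -- each entry holds both tensors as NumPy arrays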
