|
28 | 28 | # |
29 | 29 | """Validate that the commit message is ok.""" |
30 | 30 |
|
31 | | -import argparse |
32 | 31 | import re |
33 | | -import sys |
34 | | -import logging |
| 32 | +import pathlib |
| 33 | +import structlog |
| 34 | +import subprocess |
| 35 | +import typer |
| 36 | +from typing_extensions import Annotated |
| 37 | +from git import Commit, Repo |
35 | 38 |
|
36 | | -LOGGER = logging.getLogger(__name__) |
| 39 | +LOGGER = structlog.get_logger(__name__) |
37 | 40 |
|
38 | 41 | STATUS_OK = 0 |
39 | 42 | STATUS_ERROR = 1 |
40 | 43 |
|
| 44 | +repo_root = pathlib.Path( |
| 45 | + subprocess.run( |
| 46 | + "git rev-parse --show-toplevel", shell=True, text=True, capture_output=True |
| 47 | + ).stdout.strip() |
| 48 | +) |
41 | 49 |
|
42 | | -def main(argv=None): |
43 | | - """Execute Main function to validate commit messages.""" |
44 | | - parser = argparse.ArgumentParser( |
45 | | - usage="Validate the commit message. " |
46 | | - "It validates the latest message when no arguments are provided." |
47 | | - ) |
48 | | - parser.add_argument( |
49 | | - "message", |
50 | | - metavar="commit message", |
51 | | - nargs="*", |
52 | | - help="The commit message to validate", |
53 | | - ) |
54 | | - args = parser.parse_args(argv) |
| 50 | +pr_template = "" |
| 51 | +with open(repo_root / ".github" / "pull_request_template.md", "r") as r: |
| 52 | + pr_template = r.read().strip() |
| 53 | + |
| 54 | +BANNED_STRINGS = ["https://spruce.mongodb.com", "https://evergreen.mongodb.com", pr_template] |
| 55 | + |
| 56 | +VALID_SUMMARY = re.compile(r'(Revert ")?(SERVER-[0-9]+|Import wiredtiger)') |
55 | 57 |
|
56 | | - if not args.message: |
57 | | - LOGGER.error("Must specify non-empty value for --message") |
58 | | - return STATUS_ERROR |
59 | | - message = " ".join(args.message) |
60 | 58 |
|
| 59 | +def is_valid_commit(commit: Commit) -> bool: |
61 | 60 | # Valid values look like: |
62 | 61 | # 1. SERVER-\d+ |
63 | 62 | # 2. Revert "SERVER-\d+ |
64 | 63 | # 3. Import wiredtiger |
65 | 64 | # 4. Revert "Import wiredtiger |
66 | | - valid_pattern = re.compile(r'(Revert ")?(SERVER-[0-9]+|Import wiredtiger)') |
| 65 | + if not VALID_SUMMARY.match(commit.summary): |
| 66 | + LOGGER.error( |
| 67 | + "Commit did not contain a valid summary", |
| 68 | + commit_hexsha=commit.hexsha, |
| 69 | + commit_summary=commit.summary, |
| 70 | + ) |
| 71 | + return False |
| 72 | + |
| 73 | + for banned_string in BANNED_STRINGS: |
| 74 | + if banned_string in commit.message: |
| 75 | + LOGGER.error( |
| 76 | + "Commit contains banned string", |
| 77 | + banned_string=banned_string, |
| 78 | + commit_hexsha=commit.hexsha, |
| 79 | + commit_message=commit.message, |
| 80 | + ) |
| 81 | + return False |
| 82 | + |
| 83 | + return True |
| 84 | + |
| 85 | + |
| 86 | +def main( |
| 87 | + branch_name: Annotated[ |
| 88 | + str, |
| 89 | + typer.Option(envvar="BRANCH_NAME", help="Name of the branch to compare against HEAD"), |
| 90 | + ], |
| 91 | + is_commit_queue: Annotated[ |
| 92 | + str, |
| 93 | + typer.Option( |
| 94 | + envvar="IS_COMMIT_QUEUE", |
| 95 | + help="If this is being run in the commit/merge queue. Set to anything to be considered part of the commit/merge queue.", |
| 96 | + ), |
| 97 | + ] = "", |
| 98 | +): |
| 99 | + """ |
| 100 | + Validate the commit message. |
| 101 | +
|
| 102 | + It validates the latest message when no arguments are provided. |
| 103 | + """ |
| 104 | + |
| 105 | + if not is_commit_queue: |
| 106 | + LOGGER.info("Exiting early since this is not running in the commit/merge queue") |
| 107 | + raise typer.Exit(code=STATUS_OK) |
| 108 | + |
| 109 | + diff_commits = subprocess.run( |
| 110 | + ["git", "log", '--pretty=format:"%H"', f"{branch_name}...HEAD"], |
| 111 | + check=True, |
| 112 | + capture_output=True, |
| 113 | + text=True, |
| 114 | + ) |
| 115 | + # Comes back like "hash1"\n"hash2"\n... |
| 116 | + commit_hashs: list[str] = diff_commits.stdout.replace('"', "").splitlines() |
| 117 | + LOGGER.info("Diff commit hashes", commit_hashs=commit_hashs) |
| 118 | + repo = Repo(repo_root) |
| 119 | + |
| 120 | + for commit_hash in commit_hashs: |
| 121 | + commit = repo.commit(commit_hash) |
| 122 | + if not is_valid_commit(commit): |
| 123 | + LOGGER.error("Found an invalid commit", commit=commit) |
| 124 | + raise typer.Exit(code=STATUS_ERROR) |
67 | 125 |
|
68 | | - if valid_pattern.match(message): |
69 | | - return STATUS_OK |
70 | | - else: |
71 | | - LOGGER.error(f"Found a commit without a ticket\n{message}") # pylint: disable=logging-fstring-interpolation |
72 | | - return STATUS_ERROR |
| 126 | + return |
73 | 127 |
|
74 | 128 |
|
75 | 129 | if __name__ == "__main__": |
76 | | - sys.exit(main(sys.argv[1:])) |
| 130 | + typer.run(main) |
0 commit comments