diff --git a/Makefile b/Makefile index 93f47fdbb6c475..4ef8b924ef0f8a 100644 --- a/Makefile +++ b/Makefile @@ -3,14 +3,9 @@ check_dirs := examples tests src utils -# get modified files since the branch was made -fork_point_sha := $(shell git merge-base --fork-point master) -joined_dirs := $(shell echo $(check_dirs) | tr " " "|") -modified_py_files := $(shell git diff --name-only $(fork_point_sha) | egrep '^($(joined_dirs))' | egrep '\.py$$') -#$(info modified files are: $(modified_py_files)) - modified_only_fixup: - @if [ -n "$(modified_py_files)" ]; then \ + $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) + @if test -n "$(modified_py_files)"; then \ echo "Checking/fixing $(modified_py_files)"; \ black $(modified_py_files); \ isort $(modified_py_files); \ diff --git a/utils/get_modified_files.py b/utils/get_modified_files.py new file mode 100644 index 00000000000000..78d2ec128bf051 --- /dev/null +++ b/utils/get_modified_files.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# this script reports modified .py files under the desired list of top-level sub-dirs passed as a list of arguments, e.g.: +# python ./utils/get_modified_files.py utils src tests examples +# +# it uses git to find the forking point and which files were modified - i.e. files not under git won't be considered +# since the output of this script is fed into Makefile commands it doesn't print a newline after the results + +import re +import subprocess +import sys + + +fork_point_sha = subprocess.check_output("git merge-base --fork-point master".split()).decode("utf-8") +modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split() + +joined_dirs = "|".join(sys.argv[1:]) +regex = re.compile(fr"^({joined_dirs}).*?\.py$") + +relevant_modified_files = [x for x in modified_files if regex.match(x)] +print(" ".join(relevant_modified_files), end="")