Skip to content

Commit

Permalink
Add case case_sensitive in scandir (#1389)
Browse files Browse the repository at this point in the history
* add case_insensitive

* rename v

* case_insensitive to case_sensitive

* Update docstring
  • Loading branch information
Ezra-Yu committed Oct 20, 2021
1 parent c85c240 commit e8489a7
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
19 changes: 13 additions & 6 deletions mmcv/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def symlink(src, dst, overwrite=True, **kwargs):
os.symlink(src, dst, **kwargs)


def scandir(dir_path, suffix=None, recursive=False):
def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
"""Scan a directory to find the interested files.
Args:
Expand All @@ -45,6 +45,8 @@ def scandir(dir_path, suffix=None, recursive=False):
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
case_sensitive (bool, optional) : If set to False, ignore the case of
suffix. Default: True.
Returns:
A generator for all the interested files with relative paths.
Expand All @@ -57,20 +59,25 @@ def scandir(dir_path, suffix=None, recursive=False):
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')

if suffix is not None and not case_sensitive:
suffix = suffix.lower() if isinstance(suffix, str) else tuple(
item.lower() for item in suffix)

root = dir_path

def _scandir(dir_path, suffix, recursive):
def _scandir(dir_path, suffix, recursive, case_sensitive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
rel_path = osp.relpath(entry.path, root)
if suffix is None or rel_path.endswith(suffix):
_rel_path = rel_path if case_sensitive else rel_path.lower()
if suffix is None or _rel_path.endswith(suffix):
yield rel_path
elif recursive and os.path.isdir(entry.path):
# scan recursively if entry.path is a directory
yield from _scandir(
entry.path, suffix=suffix, recursive=recursive)
yield from _scandir(entry.path, suffix, recursive,
case_sensitive)

return _scandir(dir_path, suffix=suffix, recursive=recursive)
return _scandir(dir_path, suffix, recursive, case_sensitive)


def find_vcs_root(path, markers=('.git', )):
Expand Down
Empty file added tests/data/for_scan/3.TXT
Empty file.
17 changes: 15 additions & 2 deletions tests/test_utils/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_check_file_exist():

def test_scandir():
folder = osp.join(osp.dirname(osp.dirname(__file__)), 'data/for_scan')
filenames = ['a.bin', '1.txt', '2.txt', '1.json', '2.json']
filenames = ['a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT']
assert set(mmcv.scandir(folder)) == set(filenames)
assert set(mmcv.scandir(Path(folder))) == set(filenames)
assert set(mmcv.scandir(folder, '.txt')) == set(
Expand All @@ -41,7 +41,7 @@ def test_scandir():
# path of sep is `\\` in windows but `/` in linux, so osp.join should be
# used to join string for compatibility
filenames_recursive = [
'a.bin', '1.txt', '2.txt', '1.json', '2.json',
'a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT',
osp.join('sub', '1.json'),
osp.join('sub', '1.txt'), '.file'
]
Expand All @@ -54,6 +54,19 @@ def test_scandir():
filename for filename in filenames_recursive
if filename.endswith('.txt')
])
assert set(
mmcv.scandir(folder, '.TXT', recursive=True,
case_sensitive=False)) == set([
filename for filename in filenames_recursive
if filename.endswith(('.txt', '.TXT'))
])
assert set(
mmcv.scandir(
folder, ('.TXT', '.JSON'), recursive=True,
case_sensitive=False)) == set([
filename for filename in filenames_recursive
if filename.endswith(('.txt', '.json', '.TXT'))
])
with pytest.raises(TypeError):
list(mmcv.scandir(123))
with pytest.raises(TypeError):
Expand Down

0 comments on commit e8489a7

Please sign in to comment.