Skip to content

Commit

Permalink
Merge branch 'staging'
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Nov 10, 2022
2 parents b27331c + e3b4576 commit 2da1f67
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 26 deletions.
49 changes: 35 additions & 14 deletions lib/cli/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import argparse
import os
from typing import Any, List, Optional, Tuple, Union


# << FILE HANDLING >>
Expand All @@ -18,7 +19,7 @@ class _FullPaths(argparse.Action): # pylint: disable=too-few-public-methods
called directly. It is the base class for the various different file handling
methods.
"""
def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser, namespace, values, option_string=None) -> None:
if isinstance(values, (list, tuple)):
vals = [os.path.abspath(os.path.expanduser(val)) for val in values]
else:
Expand Down Expand Up @@ -68,7 +69,7 @@ class FileFullPaths(_FullPaths):
>>> filetypes="video))"
"""
# pylint: disable=too-few-public-methods
def __init__(self, *args, filetypes=None, **kwargs):
def __init__(self, *args, filetypes: Optional[str] = None, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.filetypes = filetypes

Expand Down Expand Up @@ -110,7 +111,7 @@ class FilesFullPaths(FileFullPaths): # pylint: disable=too-few-public-methods
>>> filetypes="image",
>>> nargs="+"))
"""
def __init__(self, *args, filetypes=None, **kwargs):
def __init__(self, *args, filetypes: Optional[str] = None, **kwargs) -> None:
if kwargs.get("nargs", None) is None:
opt = kwargs["option_strings"]
raise ValueError(f"nargs must be provided for FilesFullPaths: {opt}")
Expand Down Expand Up @@ -144,7 +145,6 @@ class DirOrFileFullPaths(FileFullPaths): # pylint: disable=too-few-public-metho
>>> action=DirOrFileFullPaths,
>>> filetypes="video))"
"""
pass # pylint: disable=unnecessary-pass


class DirOrFilesFullPaths(FileFullPaths): # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -175,7 +175,20 @@ class DirOrFilesFullPaths(FileFullPaths): # pylint: disable=too-few-public-meth
>>> action=DirOrFileFullPaths,
>>> filetypes="video))"
"""
pass # pylint: disable=unnecessary-pass
def __call__(self, parser, namespace, values, option_string=None) -> None:
""" Override :class:`_FullPaths` __call__ function.
The input for this option can be a space separated list of files or a single folder.
Folders can have spaces in them, so we don't want to blindly expand the paths.
We check whether the input can be resolved to a folder first before expanding.
"""
assert isinstance(values, (list, tuple))
folder = os.path.abspath(os.path.expanduser(" ".join(values)))
if os.path.isdir(folder):
setattr(namespace, self.dest, [folder])
else: # file list so call parent method
super().__call__(parser, namespace, values, option_string)


class SaveFileFullPaths(FileFullPaths):
Expand Down Expand Up @@ -235,7 +248,11 @@ class ContextFullPaths(FileFullPaths):
>>> action_option="-a"))
"""
# pylint: disable=too-few-public-methods, too-many-arguments
def __init__(self, *args, filetypes=None, action_option=None, **kwargs):
def __init__(self,
*args,
filetypes: Optional[str] = None,
action_option: Optional[str] = None,
**kwargs) -> None:
opt = kwargs["option_strings"]
if kwargs.get("nargs", None) is not None:
raise ValueError(f"nargs not allowed for ContextFullPaths: {opt}")
Expand All @@ -246,7 +263,7 @@ def __init__(self, *args, filetypes=None, action_option=None, **kwargs):
super().__init__(*args, filetypes=filetypes, **kwargs)
self.action_option = action_option

def _get_kwargs(self):
def _get_kwargs(self) -> List[Tuple[str, Any]]:
names = ["option_strings",
"dest",
"nargs",
Expand Down Expand Up @@ -280,15 +297,15 @@ class Radio(argparse.Action): # pylint: disable=too-few-public-methods
>>> action=Radio,
>>> choices=["foo", "bar"))
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
opt = kwargs["option_strings"]
if kwargs.get("nargs", None) is not None:
raise ValueError(f"nargs not allowed for Radio buttons: {opt}")
if not kwargs.get("choices", []):
raise ValueError(f"Choices must be provided for Radio buttons: {opt}")
super().__init__(*args, **kwargs)

def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser, namespace, values, option_string=None) -> None:
setattr(namespace, self.dest, values)


Expand All @@ -308,15 +325,15 @@ class MultiOption(argparse.Action): # pylint: disable=too-few-public-methods
>>> action=MultiOption,
>>> choices=["foo", "bar"))
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
opt = kwargs["option_strings"]
if not kwargs.get("nargs", []):
raise ValueError(f"nargs must be provided for MultiOption: {opt}")
if not kwargs.get("choices", []):
raise ValueError(f"Choices must be provided for MultiOption: {opt}")
super().__init__(*args, **kwargs)

def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser, namespace, values, option_string=None) -> None:
setattr(namespace, self.dest, values)


Expand Down Expand Up @@ -363,7 +380,11 @@ class Slider(argparse.Action): # pylint: disable=too-few-public-methods
>>> type=float,
>>> default=5.00))
"""
def __init__(self, *args, min_max=None, rounding=None, **kwargs):
def __init__(self,
*args,
min_max: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
rounding: Optional[int] = None,
**kwargs) -> None:
opt = kwargs["option_strings"]
if kwargs.get("nargs", None) is not None:
raise ValueError(f"nargs not allowed for Slider: {opt}")
Expand All @@ -380,7 +401,7 @@ def __init__(self, *args, min_max=None, rounding=None, **kwargs):
self.min_max = min_max
self.rounding = rounding

def _get_kwargs(self):
def _get_kwargs(self) -> List[Tuple[str, Any]]:
names = ["option_strings",
"dest",
"nargs",
Expand All @@ -394,5 +415,5 @@ def _get_kwargs(self):
"rounding"] # Decimal places to round floats to or step interval for ints
return [(name, getattr(self, name)) for name in names]

def __call__(self, parser, namespace, values, option_string=None):
def __call__(self, parser, namespace, values, option_string=None) -> None:
setattr(namespace, self.dest, values)
44 changes: 32 additions & 12 deletions scripts/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,37 @@ def n_embeddings(self) -> np.ndarray:
return retval

@classmethod
def _validate_inputs(cls,
def _files_from_folder(cls, input_location: List[str]) -> List[str]:
""" Test whether the input location is a folder and if so, return the list of contained
image files, otherwise return the original input location
Parameters
---------
input_files: list
A list of full paths to individual files or to a folder location
Returns
-------
bool
Either the original list of files provided, or the image files that exist in the
provided folder location
"""
if not input_location or len(input_location) > 1:
return input_location

test_folder = input_location[0]
if not os.path.isdir(test_folder):
logger.debug("'%s' is not a folder. Returning original list", test_folder)
return input_location

retval = [os.path.join(test_folder, fname)
for fname in os.listdir(test_folder)
if os.path.splitext(fname)[-1].lower() in _image_extensions]
logger.info("Collected files from folder '%s': %s", test_folder,
[os.path.basename(f) for f in retval])
return retval

def _validate_inputs(self,
filter_files: Optional[List[str]],
nfilter_files: Optional[List[str]]) -> Tuple[List[str], List[str]]:
""" Validates that the given filter/nfilter files exist, are image files and are unique
Expand All @@ -252,17 +282,7 @@ def _validate_inputs(cls,
retval: List[List[str]] = []

for files in (filter_files, nfilter_files):

if isinstance(files, list) and len(files) == 1 and os.path.isdir(files[0]):
# Get images from folder, if folder passed in
dirname = files[0]
files = [os.path.join(dirname, fname)
for fname in os.listdir(dirname)
if os.path.splitext(fname)[-1].lower() in _image_extensions]
logger.debug("Collected files from folder '%s': %s", dirname,
[os.path.basename(f) for f in files])

filt_files = [] if files is None else files
filt_files = [] if files is None else self._files_from_folder(files)
for file in filt_files:
if (not os.path.isfile(file) or
os.path.splitext(file)[-1].lower() not in _image_extensions):
Expand Down

0 comments on commit 2da1f67

Please sign in to comment.