From ccd144e1a768a0afc158588245dde89529d71886 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Wed, 6 Sep 2023 11:36:08 -0400 Subject: [PATCH] misc. docstring and type hint fixes [skip ci] --- .../pytorch_learner/learner_config.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py index 2c16e66de1..8ca4f4856f 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py @@ -727,9 +727,20 @@ def get_bbox_params(self) -> Optional[A.BboxParams]: def get_data_transforms(self) -> Tuple[A.BasicTransform, A.BasicTransform]: """Get albumentations transform objects for data augmentation. + Returns a 2-tuple of a "base" transform and an augmentation transform. + The base transform comprises a resize transform based on img_sz + followed by the transform specified in base_transform. The augmentation + transform comprises the base transform followed by either the transform + in aug_transform (if specified) or the transforms in the augmentors + field. + + The augmentation transform is intended to be used for training data, + and the base transform for all other data where data augmentation is + not desirable, such as validation or prediction. + Returns: - 1st tuple arg: a transform that doesn't do any data augmentation - 2nd tuple arg: a transform with data augmentation + Tuple[A.BasicTransform, A.BasicTransform]: base transform and + augmentation transform. """ bbox_params = self.get_bbox_params() base_tfs = [A.Resize(self.img_sz, self.img_sz)] @@ -1013,7 +1024,7 @@ def get_data_dirs(self, uri: Union[str, List[str]], (optinally) "test" subdirectories. Args: - uri (Union[str, List[str]]): a URI or a list of URIs of one of the + uri (Union[str, List[str]]): A URI or a list of URIs of one of the following: (1) a URI of a directory containing "train", "valid", and @@ -1021,9 +1032,12 @@ def get_data_dirs(self, uri: Union[str, List[str]], (2) a URI of a zip file containing (1) (3) a list of (2) (4) a URI of a directory containing zip files containing (1) + unzip_dir (str): Directory where zip files will be extrated to, if + needed. Returns: - paths to directories that each contain contents of one zip file + List[str]: Paths to directories that each contain contents of one + zip file. """ def is_data_dir(uri: str) -> bool: @@ -1066,11 +1080,12 @@ def unzip_data(self, zip_uris: List[str], unzip_dir: str) -> List[str]: """Unzip dataset zip files. Args: - zip_uris (List[str]): a list of URIs of zip files: - unzip_dir (str): directory where zip files will be extrated to. + zip_uris (List[str]): A list of URIs of zip files: + unzip_dir (str): Directory where zip files will be extrated to. Returns: - paths to directories that each contain contents of one zip file + List[str]: Paths to directories that each contain contents of one + zip file. """ data_dirs = [] @@ -1413,7 +1428,7 @@ def build(self, model_weights_path: Optional[str] = None, model_def_path: Optional[str] = None, loss_def_path: Optional[str] = None, - training=True) -> 'Learner': + training: bool = True) -> 'Learner': """Returns a Learner instantiated using this Config. Args: