From 3d69a9be4e9a2b13d87b79b339349e9efd1b40fc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 3 Aug 2020 11:00:48 +0200 Subject: [PATCH] add typehints for torchvision.datasets.mnist (#2532) --- torchvision/datasets/mnist.py | 79 ++++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 4bb09955e70..2c037270ed9 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -7,6 +7,7 @@ import torch import codecs import string +from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union from .utils import download_url, download_and_extract_archive, extract_archive, \ verify_str_arg @@ -60,8 +61,14 @@ def test_data(self): warnings.warn("test_data has been renamed data") return self.data - def __init__(self, root, train=True, transform=None, target_transform=None, - download=False): + def __init__( + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: super(MNIST, self).__init__(root, transform=transform, target_transform=target_transform) self.train = train # training set or test set @@ -79,7 +86,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None, data_file = self.test_file self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index @@ -101,28 +108,28 @@ def __getitem__(self, index): return img, target - def __len__(self): + def __len__(self) -> int: return len(self.data) @property - def raw_folder(self): + def raw_folder(self) -> str: return os.path.join(self.root, self.__class__.__name__, 'raw') @property - def processed_folder(self): + def processed_folder(self) -> str: return os.path.join(self.root, self.__class__.__name__, 'processed') @property - def class_to_idx(self): + def class_to_idx(self) -> Dict[str, int]: return {_class: i for i, _class in enumerate(self.classes)} - def _check_exists(self): + def _check_exists(self) -> bool: return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and os.path.exists(os.path.join(self.processed_folder, self.test_file))) - def download(self): + def download(self) -> None: """Download the MNIST data if it doesn't exist in processed_folder already.""" if self._check_exists(): @@ -154,7 +161,7 @@ def download(self): print('Done!') - def extra_repr(self): + def extra_repr(self) -> str: return "Split: {}".format("Train" if self.train is True else "Test") @@ -251,7 +258,7 @@ class EMNIST(MNIST): 'mnist': list(string.digits), } - def __init__(self, root, split, **kwargs): + def __init__(self, root: str, split: str, **kwargs: Any) -> None: self.split = verify_str_arg(split, "split", self.splits) self.training_file = self._training_file(split) self.test_file = self._test_file(split) @@ -259,14 +266,14 @@ def __init__(self, root, split, **kwargs): self.classes = self.classes_split_dict[self.split] @staticmethod - def _training_file(split): + def _training_file(split) -> str: return 'training_{}.pt'.format(split) @staticmethod - def _test_file(split): + def _test_file(split) -> str: return 'test_{}.pt'.format(split) - def download(self): + def download(self) -> None: """Download the EMNIST data if it doesn't exist in processed_folder already.""" import shutil @@ -343,7 +350,7 @@ class QMNIST(MNIST): 'test50k': 'test', 'nist': 'nist' } - resources = { # type: ignore[assignment] + resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment] 'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz', 'ed72d4157d28c017586c42bc6afe6370'), ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz', @@ -360,7 +367,10 @@ class QMNIST(MNIST): classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] - def __init__(self, root, what=None, compat=True, train=True, **kwargs): + def __init__( + self, root: str, what: Optional[str] = None, compat: bool = True, + train: bool = True, **kwargs: Any + ) -> None: if what is None: what = 'train' if train else 'test' self.what = verify_str_arg(what, "what", tuple(self.subsets.keys())) @@ -370,7 +380,7 @@ def __init__(self, root, what=None, compat=True, train=True, **kwargs): self.test_file = self.data_file super(QMNIST, self).__init__(root, train, **kwargs) - def download(self): + def download(self) -> None: """Download the QMNIST data if it doesn't exist in processed_folder already. Note that we only download what has been asked for (argument 'what'). """ @@ -405,7 +415,7 @@ def download(self): with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f: torch.save((data, targets), f) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple[Any, Any]: # redefined to handle the compat flag img, target = self.data[index], self.targets[index] img = Image.fromarray(img.numpy(), mode='L') @@ -417,15 +427,15 @@ def __getitem__(self, index): target = self.target_transform(target) return img, target - def extra_repr(self): + def extra_repr(self) -> str: return "Split: {}".format(self.what) -def get_int(b): +def get_int(b: bytes) -> int: return int(codecs.encode(b, 'hex'), 16) -def open_maybe_compressed_file(path): +def open_maybe_compressed_file(path: Union[str, IO]) -> IO: """Return a file object that possibly decompresses 'path' on the fly. Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. """ @@ -440,19 +450,20 @@ def open_maybe_compressed_file(path): return open(path, 'rb') -def read_sn3_pascalvincent_tensor(path, strict=True): +SN3_PASCALVINCENT_TYPEMAP = { + 8: (torch.uint8, np.uint8, np.uint8), + 9: (torch.int8, np.int8, np.int8), + 11: (torch.int16, np.dtype('>i2'), 'i2'), + 12: (torch.int32, np.dtype('>i4'), 'i4'), + 13: (torch.float32, np.dtype('>f4'), 'f4'), + 14: (torch.float64, np.dtype('>f8'), 'f8') +} + + +def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor: """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). Argument may be a filename, compressed filename, or file object. """ - # typemap - if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'): - read_sn3_pascalvincent_tensor.typemap = { - 8: (torch.uint8, np.uint8, np.uint8), - 9: (torch.int8, np.int8, np.int8), - 11: (torch.int16, np.dtype('>i2'), 'i2'), - 12: (torch.int32, np.dtype('>i4'), 'i4'), - 13: (torch.float32, np.dtype('>f4'), 'f4'), - 14: (torch.float64, np.dtype('>f8'), 'f8')} # read with open_maybe_compressed_file(path) as f: data = f.read() @@ -462,14 +473,14 @@ def read_sn3_pascalvincent_tensor(path, strict=True): ty = magic // 256 assert nd >= 1 and nd <= 3 assert ty >= 8 and ty <= 14 - m = read_sn3_pascalvincent_tensor.typemap[ty] + m = SN3_PASCALVINCENT_TYPEMAP[ty] s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) assert parsed.shape[0] == np.prod(s) or not strict return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) -def read_label_file(path): +def read_label_file(path: str) -> torch.Tensor: with open(path, 'rb') as f: x = read_sn3_pascalvincent_tensor(f, strict=False) assert(x.dtype == torch.uint8) @@ -477,7 +488,7 @@ def read_label_file(path): return x.long() -def read_image_file(path): +def read_image_file(path: str) -> torch.Tensor: with open(path, 'rb') as f: x = read_sn3_pascalvincent_tensor(f, strict=False) assert(x.dtype == torch.uint8)