Skip to content

Commit

Permalink
add typehints for torchvision.datasets.mnist (pytorch#2532)
Browse files (browse the repository at this point in the history)
  • Loading branch information
pmeier authored and bryant1410 committed Nov 22, 2020
1 parent 314629a commit 3d69a9b
Showing 1 changed file with 45 additions and 34 deletions.
79 changes: 45 additions & 34 deletions torchvision/datasets/mnist.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import codecs
import string
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from .utils import download_url, download_and_extract_archive, extract_archive, \
verify_str_arg

Expand Down Expand Up @@ -60,8 +61,14 @@ def test_data(self):
warnings.warn("test_data has been renamed data")
return self.data

def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(MNIST, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
Expand All @@ -79,7 +86,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None,
data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Expand All @@ -101,28 +108,28 @@ def __getitem__(self, index):

return img, target

def __len__(self):
def __len__(self) -> int:
return len(self.data)

@property
def raw_folder(self):
def raw_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, 'raw')

@property
def processed_folder(self):
def processed_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, 'processed')

@property
def class_to_idx(self):
def class_to_idx(self) -> Dict[str, int]:
return {_class: i for i, _class in enumerate(self.classes)}

def _check_exists(self):
def _check_exists(self) -> bool:
return (os.path.exists(os.path.join(self.processed_folder,
self.training_file)) and
os.path.exists(os.path.join(self.processed_folder,
self.test_file)))

def download(self):
def download(self) -> None:
"""Download the MNIST data if it doesn't exist in processed_folder already."""

if self._check_exists():
Expand Down Expand Up @@ -154,7 +161,7 @@ def download(self):

print('Done!')

def extra_repr(self):
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")


Expand Down Expand Up @@ -251,22 +258,22 @@ class EMNIST(MNIST):
'mnist': list(string.digits),
}

def __init__(self, root, split, **kwargs):
def __init__(self, root: str, split: str, **kwargs: Any) -> None:
self.split = verify_str_arg(split, "split", self.splits)
self.training_file = self._training_file(split)
self.test_file = self._test_file(split)
super(EMNIST, self).__init__(root, **kwargs)
self.classes = self.classes_split_dict[self.split]

@staticmethod
def _training_file(split):
def _training_file(split) -> str:
return 'training_{}.pt'.format(split)

@staticmethod
def _test_file(split):
def _test_file(split) -> str:
return 'test_{}.pt'.format(split)

def download(self):
def download(self) -> None:
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
import shutil

Expand Down Expand Up @@ -343,7 +350,7 @@ class QMNIST(MNIST):
'test50k': 'test',
'nist': 'nist'
}
resources = { # type: ignore[assignment]
resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment]
'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
'ed72d4157d28c017586c42bc6afe6370'),
('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz',
Expand All @@ -360,7 +367,10 @@ class QMNIST(MNIST):
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

def __init__(self, root, what=None, compat=True, train=True, **kwargs):
def __init__(
self, root: str, what: Optional[str] = None, compat: bool = True,
train: bool = True, **kwargs: Any
) -> None:
if what is None:
what = 'train' if train else 'test'
self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
Expand All @@ -370,7 +380,7 @@ def __init__(self, root, what=None, compat=True, train=True, **kwargs):
self.test_file = self.data_file
super(QMNIST, self).__init__(root, train, **kwargs)

def download(self):
def download(self) -> None:
"""Download the QMNIST data if it doesn't exist in processed_folder already.
Note that we only download what has been asked for (argument 'what').
"""
Expand Down Expand Up @@ -405,7 +415,7 @@ def download(self):
with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
torch.save((data, targets), f)

def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img.numpy(), mode='L')
Expand All @@ -417,15 +427,15 @@ def __getitem__(self, index):
target = self.target_transform(target)
return img, target

def extra_repr(self):
def extra_repr(self) -> str:
return "Split: {}".format(self.what)


def get_int(b):
def get_int(b: bytes) -> int:
return int(codecs.encode(b, 'hex'), 16)


def open_maybe_compressed_file(path):
def open_maybe_compressed_file(path: Union[str, IO]) -> IO:
"""Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
"""
Expand All @@ -440,19 +450,20 @@ def open_maybe_compressed_file(path):
return open(path, 'rb')


def read_sn3_pascalvincent_tensor(path, strict=True):
SN3_PASCALVINCENT_TYPEMAP = {
8: (torch.uint8, np.uint8, np.uint8),
9: (torch.int8, np.int8, np.int8),
11: (torch.int16, np.dtype('>i2'), 'i2'),
12: (torch.int32, np.dtype('>i4'), 'i4'),
13: (torch.float32, np.dtype('>f4'), 'f4'),
14: (torch.float64, np.dtype('>f8'), 'f8')
}


def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# typemap
if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
read_sn3_pascalvincent_tensor.typemap = {
8: (torch.uint8, np.uint8, np.uint8),
9: (torch.int8, np.int8, np.int8),
11: (torch.int16, np.dtype('>i2'), 'i2'),
12: (torch.int32, np.dtype('>i4'), 'i4'),
13: (torch.float32, np.dtype('>f4'), 'f4'),
14: (torch.float64, np.dtype('>f8'), 'f8')}
# read
with open_maybe_compressed_file(path) as f:
data = f.read()
Expand All @@ -462,22 +473,22 @@ def read_sn3_pascalvincent_tensor(path, strict=True):
ty = magic // 256
assert nd >= 1 and nd <= 3
assert ty >= 8 and ty <= 14
m = read_sn3_pascalvincent_tensor.typemap[ty]
m = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
assert parsed.shape[0] == np.prod(s) or not strict
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


def read_label_file(path):
def read_label_file(path: str) -> torch.Tensor:
with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False)
assert(x.dtype == torch.uint8)
assert(x.ndimension() == 1)
return x.long()


def read_image_file(path):
def read_image_file(path: str) -> torch.Tensor:
with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False)
assert(x.dtype == torch.uint8)
Expand Down

0 comments on commit 3d69a9b

Please sign in to comment.