Skip to content

Commit

Permalink
add typehints for torchvision.datasets.usps (pytorch#2538)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored and bryant1410 committed Nov 22, 2020
1 parent 3d69a9b commit 48a3a3c
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions torchvision/datasets/usps.py
@@ -1,6 +1,7 @@
from PIL import Image
import os
import numpy as np
from typing import Any, Callable, cast, Optional, Tuple

from .utils import download_url
from .vision import VisionDataset
Expand Down Expand Up @@ -36,8 +37,14 @@ class USPS(VisionDataset):
],
}

def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(USPS, self).__init__(root, transform=transform,
target_transform=target_transform)
split = 'train' if train else 'test'
Expand All @@ -52,13 +59,13 @@ def __init__(self, root, train=True, transform=None, target_transform=None,
raw_data = [line.decode().split() for line in fp.readlines()]
imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data]

self.data = imgs
self.targets = targets

def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Expand All @@ -80,5 +87,5 @@ def __getitem__(self, index):

return img, target

def __len__(self):
def __len__(self) -> int:
return len(self.data)

0 comments on commit 48a3a3c

Please sign in to comment.