Skip to content

Commit

Permalink
add typehints to torchvision.datasets.sbu (pytorch#2536)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored and bryant1410 committed Nov 22, 2020
1 parent 79cc138 commit 314629a
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions torchvision/datasets/sbu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from PIL import Image
from .utils import download_url, check_integrity
from typing import Any, Callable, Optional, Tuple

import os
from .vision import VisionDataset
Expand All @@ -23,7 +24,13 @@ class SBU(VisionDataset):
filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285'

def __init__(self, root, transform=None, target_transform=None, download=True):
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
) -> None:
super(SBU, self).__init__(root, transform=transform,
target_transform=target_transform)

Expand All @@ -50,7 +57,7 @@ def __init__(self, root, transform=None, target_transform=None, download=True):
self.photos.append(photo)
self.captions.append(caption)

def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Expand All @@ -69,19 +76,19 @@ def __getitem__(self, index):

return img, target

def __len__(self):
def __len__(self) -> int:
"""The number of photos in the dataset."""
return len(self.photos)

def _check_integrity(self):
def _check_integrity(self) -> bool:
"""Check the md5 checksum of the downloaded tarball."""
root = self.root
fpath = os.path.join(root, self.filename)
if not check_integrity(fpath, self.md5_checksum):
return False
return True

def download(self):
def download(self) -> None:
"""Download and extract the tarball, and download each individual photo."""
import tarfile

Expand Down

0 comments on commit 314629a

Please sign in to comment.