-
Notifications
You must be signed in to change notification settings - Fork 24
/
cars.py
81 lines (67 loc) · 3.19 KB
/
cars.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import scipy.io as sio
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torchvision.datasets.utils import extract_archive
class Cars(VisionDataset):
    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset.

    Reads the annotation file ``cars_annos.mat`` located under ``root`` and
    exposes ``(image, label)`` pairs for either the training or the test split.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the dataset files are not present under ``root`` and
            ``download`` is False.
    """

    # Maps a short key to ``(download URL, file name stored under root)``.
    file_list = {
        'imgs': ('http://imagenet.stanford.edu/internal/car196/car_ims.tgz', 'car_ims.tgz'),
        'annos': ('http://imagenet.stanford.edu/internal/car196/cars_annos.mat', 'cars_annos.mat')
    }

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(Cars, self).__init__(root, transform=transform, target_transform=target_transform)
        self.loader = default_loader
        self.train = train

        # Only the presence of the two files is checked; their contents are
        # not validated despite the wording of the message below.
        if self._check_exists():
            print('Files already downloaded and verified.')
        elif download:
            self._download()
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')

        # List of (path relative to root, integer label) for the chosen split.
        self.samples = self._load_samples()

    def _load_samples(self):
        """Parse the annotation file and return the samples of the selected split.

        Returns:
            list: ``(relative_image_path, label)`` tuples, in annotation-file order.
        """
        annos_path = os.path.join(self.root, self.file_list['annos'][1])
        annotations = sio.loadmat(annos_path)['annotations'][0]

        # NOTE(review): fields are accessed by position -- first = image path,
        # second-to-last = class id (presumably 1-based, hence the ``- 1``),
        # last = test-split flag. Confirm against the annotation file layout.
        #
        # ``.item()`` pulls the Python scalar out of the single-element array
        # explicitly: calling ``int()`` directly on an array with ndim > 0 is
        # deprecated since NumPy 1.25.
        return [
            (str(item[0][0]), int(item[-2][0].item()) - 1)
            for item in annotations
            if self.train != bool(item[-1][0])
        ]

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(image, target)`` where ``image`` is the result of
            ``self.loader`` (after ``transform``, if set) and ``target`` is the
            label (after ``target_transform``, if set).
        """
        path, target = self.samples[index]
        # Stored paths are relative to the dataset root.
        path = os.path.join(self.root, path)

        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def _check_exists(self):
        """Return True if both the image archive and the annotation file exist under root."""
        return (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
                and os.path.exists(os.path.join(self.root, self.file_list['annos'][1])))

    def _download(self):
        """Download every entry of ``file_list`` into root, then extract the image archive."""
        print('Downloading...')
        for url, filename in self.file_list.values():
            download_url(url, root=self.root, filename=filename)

        print('Extracting...')
        archive = os.path.join(self.root, self.file_list['imgs'][1])
        extract_archive(archive)
if __name__ == '__main__':
    # Manual smoke test: build the training split first, then the test split,
    # from files already present under ./cars (nothing is downloaded).
    train_dataset, test_dataset = (
        Cars('./cars', train=use_train_split, download=False)
        for use_train_split in (True, False)
    )