-
Notifications
You must be signed in to change notification settings - Fork 3
/
data_loader.py
119 lines (103 loc) · 5.7 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
import os
def getSVHN(batch_size, TF, data_root='data', train=True, val=True, **kwargs):
    """Build DataLoader(s) for the SVHN dataset.

    The dataset is stored under ``<data_root>/svhn`` (``~`` is expanded) and
    is requested with ``download=True``.

    Args:
        batch_size: Batch size passed to every DataLoader.
        TF: Transform handed to the dataset as ``transform``.
        data_root: Parent directory that holds the ``svhn`` folder.
        train: If truthy, include a shuffled loader over the ``'train'`` split.
        val: If truthy, include an unshuffled loader over the ``'test'`` split.
        **kwargs: Accepted so callers can pass extra options; ``input_size``
            is popped and nothing in ``kwargs`` (e.g. ``num_workers``) is
            forwarded to the dataset or the DataLoader.

    Returns:
        The loader itself when exactly one split is requested; otherwise a
        list (``[train_loader, test_loader]`` when both are requested, an
        empty list when neither is).
    """
    kwargs.pop('input_size', None)
    svhn_dir = os.path.expanduser(os.path.join(data_root, 'svhn'))

    # (requested?, SVHN split name, shuffle?) in the order loaders are returned.
    split_plan = ((train, 'train', True), (val, 'test', False))
    loaders = []
    for requested, split_name, do_shuffle in split_plan:
        if not requested:
            continue
        dataset = datasets.SVHN(
            root=svhn_dir, split=split_name, download=True, transform=TF)
        loaders.append(torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=do_shuffle))

    if len(loaders) == 1:
        return loaders[0]
    return loaders
def getCIFAR10(batch_size, TF, data_root='data', train=True, val=True, **kwargs):
    """Build DataLoader(s) for the CIFAR-10 dataset.

    The dataset is stored under ``<data_root>/cifar10`` (``~`` is expanded)
    and is requested with ``download=True``.

    Args:
        batch_size: Batch size passed to every DataLoader.
        TF: Transform handed to the dataset as ``transform``.
        data_root: Parent directory that holds the ``cifar10`` folder.
        train: If truthy, include a shuffled loader over the training set.
        val: If truthy, include an unshuffled loader over the test set.
        **kwargs: Accepted so callers can pass extra options; ``input_size``
            is popped and nothing in ``kwargs`` (e.g. ``num_workers``) is
            forwarded to the dataset or the DataLoader.

    Returns:
        The loader itself when exactly one split is requested; otherwise a
        list (``[train_loader, test_loader]`` when both are requested, an
        empty list when neither is).
    """
    kwargs.pop('input_size', None)
    cifar_dir = os.path.expanduser(os.path.join(data_root, 'cifar10'))

    # (requested?, use training set?) -- the training loader is the only
    # one that shuffles, so the second flag drives both arguments.
    split_plan = ((train, True), (val, False))
    loaders = []
    for requested, is_train_split in split_plan:
        if not requested:
            continue
        dataset = datasets.CIFAR10(
            root=cifar_dir, train=is_train_split, download=True, transform=TF)
        loaders.append(torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train_split))

    if len(loaders) == 1:
        return loaders[0]
    return loaders
def getCIFAR100(batch_size, TF, data_root='data', TTF=None, train=True, val=True, **kwargs):
    """Build DataLoader(s) for the CIFAR-100 dataset.

    The dataset is stored under ``<data_root>/cifar100`` (``~`` is expanded)
    and is requested with ``download=True``.

    Args:
        batch_size: Batch size passed to every DataLoader.
        TF: Transform handed to the dataset as ``transform``.
        data_root: Parent directory that holds the ``cifar100`` folder.
        TTF: Optional transform handed to the dataset as ``target_transform``.
        train: If truthy, include a shuffled loader over the training set.
        val: If truthy, include an unshuffled loader over the test set.
        **kwargs: Accepted so callers can pass extra options; ``input_size``
            is popped and nothing in ``kwargs`` (e.g. ``num_workers``) is
            forwarded to the dataset or the DataLoader.

    Returns:
        The loader itself when exactly one split is requested; otherwise a
        list (``[train_loader, test_loader]`` when both are requested, an
        empty list when neither is).
    """
    kwargs.pop('input_size', None)
    cifar_dir = os.path.expanduser(os.path.join(data_root, 'cifar100'))

    # (requested?, use training set?) -- the training loader is the only
    # one that shuffles, so the second flag drives both arguments.
    split_plan = ((train, True), (val, False))
    loaders = []
    for requested, is_train_split in split_plan:
        if not requested:
            continue
        dataset = datasets.CIFAR100(
            root=cifar_dir, train=is_train_split, download=True,
            transform=TF, target_transform=TTF)
        loaders.append(torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train_split))

    if len(loaders) == 1:
        return loaders[0]
    return loaders
def getIMAGENET1K(batch_size, TF, data_root='data', train=True, val=True, **kwargs):
    """Build DataLoader(s) for an ImageNet-1k directory tree.

    Images are read with ``ImageFolder`` from ``<data_root>/imagenet1k/train``
    and ``<data_root>/imagenet1k/val`` (``~`` is expanded); nothing is
    downloaded by this function.

    Args:
        batch_size: Batch size passed to every DataLoader.
        TF: Transform handed to ``ImageFolder`` as ``transform``.
        data_root: Parent directory that holds the ``imagenet1k`` folder.
        train: If truthy, include a shuffled loader over the ``train`` folder.
        val: If truthy, include an unshuffled loader over the ``val`` folder.
        **kwargs: Accepted so callers can pass extra options; ``input_size``
            is popped and nothing in ``kwargs`` (e.g. ``num_workers``) is
            forwarded to the dataset or the DataLoader.

    Returns:
        The loader itself when exactly one split is requested; otherwise a
        list (``[train_loader, test_loader]`` when both are requested, an
        empty list when neither is).
    """
    kwargs.pop('input_size', None)
    imagenet_dir = os.path.expanduser(os.path.join(data_root, 'imagenet1k'))

    # (requested?, sub-folder name, shuffle?) in the order loaders are returned.
    split_plan = ((train, 'train', True), (val, 'val', False))
    loaders = []
    for requested, subfolder, do_shuffle in split_plan:
        if not requested:
            continue
        dataset = ImageFolder(os.path.join(imagenet_dir, subfolder), transform=TF)
        loaders.append(torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=do_shuffle))

    if len(loaders) == 1:
        return loaders[0]
    return loaders
def getTargetDataSet(args, data_type, batch_size, input_TF, dataroot):
    """Return the train and test loaders for a named target dataset.

    Args:
        args: Unused by this function; kept so existing call sites keep working.
        data_type: Dataset name: ``'cifar10'``, ``'cifar100'`` or
            ``'imagenet1k'``.
        batch_size: Batch size passed to the dataset loader function.
        input_TF: Image transform passed to the dataset loader function.
        dataroot: Root data directory passed to the dataset loader function.

    Returns:
        A ``(train_loader, test_loader)`` tuple.

    Raises:
        ValueError: If ``data_type`` is not one of the supported names.
            (Previously this surfaced as an UnboundLocalError at ``return``.)
    """
    if data_type == 'cifar10':
        train_loader, test_loader = getCIFAR10(
            batch_size=batch_size, TF=input_TF, data_root=dataroot, num_workers=1)
    elif data_type == 'cifar100':
        train_loader, test_loader = getCIFAR100(
            batch_size=batch_size, TF=input_TF, data_root=dataroot, num_workers=1)
    elif data_type == 'imagenet1k':
        train_loader, test_loader = getIMAGENET1K(
            batch_size=batch_size, TF=input_TF, data_root=dataroot, num_workers=8)
    else:
        # Fail loudly with the offending name instead of hitting unbound locals.
        raise ValueError(
            "Unsupported target data_type: {!r} "
            "(expected 'cifar10', 'cifar100' or 'imagenet1k')".format(data_type))
    return train_loader, test_loader
def _zero_target(_target):
    """Map every label to 0 (used as the CIFAR-100 ``target_transform``)."""
    # A named module-level function rather than a lambda so it stays picklable.
    return 0


def _resize_to_64(input_TF):
    """Return a transform that resizes to 64x64 and then applies *input_TF*."""
    return transforms.Compose([transforms.Resize((64, 64)), input_TF])


def _image_folder_loader(root, transform, batch_size, **loader_kwargs):
    """Build an unshuffled DataLoader over the ImageFolder at *root* (``~`` expanded)."""
    dataset = datasets.ImageFolder(os.path.expanduser(root), transform=transform)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, **loader_kwargs)


def getNonTargetDataSet(args, data_type, batch_size, input_TF, dataroot):
    """Return the test loader for a named non-target dataset.

    Supported ``data_type`` values:

    * ``'cifar10'``, ``'cifar100'``, ``'svhn'`` -- test split from the
      matching loader function; CIFAR-100 labels are all mapped to 0.
    * ``'cifar10_64'``, ``'cifar100_64'``, ``'svhn_64'`` -- same, with a
      64x64 resize applied before ``input_TF``.
    * ``'imagenet_resize'``, ``'lsun_resize'``, ``'imagenet-o-64'`` --
      ``ImageFolder`` under ``<dataroot>/Imagenet_resize``,
      ``<dataroot>/LSUN_resize`` and ``<dataroot>/imagenet-o-64``.
    * ``'imagenet-o'`` -- ``ImageFolder`` under ``data/imagenet-o``, loaded
      with ``num_workers=8``.

    Args:
        args: Unused by this function; kept so existing call sites keep working.
        data_type: One of the dataset names listed above.
        batch_size: Batch size for the returned loader.
        input_TF: Image transform applied to every sample.
        dataroot: Root data directory.

    Returns:
        An unshuffled DataLoader over the requested dataset.

    Raises:
        ValueError: If ``data_type`` is not one of the supported names.
            (Previously this surfaced as an UnboundLocalError at ``return``.)
    """
    if data_type == 'cifar10':
        _, test_loader = getCIFAR10(
            batch_size=batch_size, TF=input_TF, data_root=dataroot, num_workers=1)
    elif data_type == 'cifar100':
        _, test_loader = getCIFAR100(
            batch_size=batch_size, TF=input_TF, data_root=dataroot,
            TTF=_zero_target, num_workers=1)
    elif data_type == 'svhn':
        _, test_loader = getSVHN(
            batch_size=batch_size, TF=input_TF, data_root=dataroot, num_workers=1)
    elif data_type == 'imagenet_resize':
        test_loader = _image_folder_loader(
            os.path.join(dataroot, 'Imagenet_resize'), input_TF, batch_size)
    elif data_type == 'lsun_resize':
        test_loader = _image_folder_loader(
            os.path.join(dataroot, 'LSUN_resize'), input_TF, batch_size)
    elif data_type == 'cifar10_64':
        _, test_loader = getCIFAR10(
            batch_size=batch_size, TF=_resize_to_64(input_TF),
            data_root=dataroot, num_workers=1)
    elif data_type == 'cifar100_64':
        _, test_loader = getCIFAR100(
            batch_size=batch_size, TF=_resize_to_64(input_TF),
            data_root=dataroot, TTF=_zero_target, num_workers=1)
    elif data_type == 'svhn_64':
        _, test_loader = getSVHN(
            batch_size=batch_size, TF=_resize_to_64(input_TF),
            data_root=dataroot, num_workers=1)
    elif data_type == 'imagenet-o-64':
        test_loader = _image_folder_loader(
            os.path.join(dataroot, 'imagenet-o-64'), input_TF, batch_size)
    elif data_type == 'imagenet-o':
        # NOTE(review): this branch reads from the literal 'data' directory and
        # ignores `dataroot`, unlike every other branch -- confirm whether that
        # is intentional. Behaviour is kept as in the original.
        test_loader = _image_folder_loader(
            os.path.join('data', 'imagenet-o'), input_TF, batch_size, num_workers=8)
    else:
        # Fail loudly with the offending name instead of hitting an unbound local.
        raise ValueError(
            "Unsupported non-target data_type: {!r}".format(data_type))
    return test_loader