/
mnist_noisy.py
145 lines (115 loc) · 4.49 KB
/
mnist_noisy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
'''
Created on Oct 23, 2017
@author: longxiang
'''
import shutil
from torch.utils.data import Dataset
from mnist_sampler import *
# Seed NumPy's global RNG at import time so that generated datasets are
# reproducible across runs (assumes the samplers draw from np.random -- TODO confirm).
np.random.seed(1234)

# Repeat factors: how many times each digit of a split is re-used on new noise.
K_train, K_test = 5, 1

# Number of pure-noise (label 10) samples per split. The per-repeat counts
# presumably mirror one MNIST class's size (~6000 train / ~1000 test) -- verify.
N_noisy_train = 6000 * K_train
N_noisy_test = 1000 * K_test

# Canvas height and width of every generated image.
H = W = 28
def generate_dataset(train=True, n_noisy=N_noisy_train, k=K_train):
    """Build the list of ``(image, label)`` pairs for one noisy-MNIST split.

    For each class 0-9, every entry of ``num_sampler.numbers[label]`` is
    combined with a freshly sampled noise image via ``put_numbers``. That
    pass is repeated ``k`` times. Afterwards ``n_noisy`` noise-only images
    are appended with label 10.

    Args:
        train: forwarded to ``load_mnist`` to choose the MNIST split.
        n_noisy: number of noise-only (label 10) samples to append.
        k: how many times each class's digits are re-used.

    Returns:
        list of ``(img, label)`` tuples: digit samples first, then noise.
    """
    canvas_shape = (H, W)
    noise = get_noisy_sampler()
    digits = get_number_sampler(load_mnist(train))
    samples = []

    for label in range(10):
        print('Generate for number %d ...' % label)
        count = 0  # per-class progress counter
        for _ in range(k):
            for num in digits.numbers[label]:
                canvas = noise.sample(canvas_shape)
                # put_numbers' return value is ignored, so it is called for its
                # effect on `canvas` (presumably drawing the digit in place).
                put_numbers(canvas, num)
                samples.append((canvas, label))
                count += 1
                if count % 10000 == 0:
                    print(count)

    print('Generate for noisy ...')
    for produced in range(1, n_noisy + 1):
        samples.append((noise.sample(canvas_shape), 10))
        if produced % 10000 == 0:
            print(produced)
    return samples
def save_pkl(datas, path):
    """Serialize ``(image, label)`` pairs to ``path`` via ``dump_pkl``.

    The pairs are transposed into two parallel tuples, ``(images, labels)``.
    That is the layout ``MNISTNoisy.__init__`` unpacks from ``load_pkl``.

    Args:
        datas: iterable of ``(image, label)`` pairs, e.g. the output of
            ``generate_dataset``. May be empty.
        path: destination file path.
    """
    print('save pkl to: %s' % path)
    pairs = list(datas)
    if pairs:
        images, labels = zip(*pairs)
    else:
        # zip(*[]) produces nothing, so unpacking it into two names would
        # raise an opaque ValueError; store an empty dataset instead.
        images, labels = (), ()
    dump_pkl((images, labels), path)
def save_images(datas, image_dir):
    """Write each ``(image, label)`` pair in ``datas`` as a BMP file.

    Files go into ``image_dir``, which is created if it does not exist.
    Each file is named ``<label>_<index>.bmp``, where the index is a 6-digit
    running counter over all pairs, starting at 0. Each array is converted
    with ``to_image`` before being saved.
    """
    print('save images to: %s' % image_dir)
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)
    for position, (array, label) in enumerate(datas):
        picture = to_image(array)
        picture.save(os.path.join(image_dir, '%d_%06d.bmp' % (label, position)))
class MNISTNoisy(Dataset):
    """Noisy-MNIST dataset: digits 0-9 drawn onto noise plus a noise-only class 10.

    Samples are read from the pickles written by ``generate`` under ``root``.
    Indexing yields ``(img, target, name)``, where ``name`` is
    ``'<label>_<index:06d>'``. That matches the file names ``save_images`` writes.
    """

    training_images_root = 'noisy_train'
    test_images_root = 'noisy_test'
    training_file = 'noisy_train.pkl'
    test_file = 'noisy_test.pkl'

    def __init__(self, root=config.data_dir, train=True, transform=None,
                 target_transform=None, generate=False, force_generate=False):
        """Load one split of the dataset, optionally generating it first.

        Args:
            root: directory holding the pickles; ``~`` is expanded.
            train: load the training split if True, otherwise the test split.
            transform: optional callable applied to the PIL image.
            target_transform: optional callable applied to the label.
            generate: if True, call ``generate`` before loading.
            force_generate: passed to ``generate``; regenerate even if the
                pickles already exist.

        Raises:
            RuntimeError: if either split's pickle is missing under ``root``.
        """
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if generate:
            self.generate(force_generate)

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use generate=True to generate it')

        if self.train:
            self.train_data, self.train_labels = load_pkl(
                os.path.join(self.root, self.training_file))
        else:
            self.test_data, self.test_labels = load_pkl(
                os.path.join(self.root, self.test_file))

    def __getitem__(self, index):
        """Return ``(img, target, name)`` for sample ``index`` of the loaded split."""
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]
        # Build the name from the raw stored label *before* target_transform
        # runs: the transform may return a non-integer (e.g. a one-hot
        # vector) that '%d' cannot format.
        name = '%d_%06d' % (target, index)
        # mode='L' assumes img is a 2-D 8-bit grayscale array -- TODO confirm dtype.
        img = Image.fromarray(img, mode='L')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, name

    def __len__(self):
        """Number of samples in the loaded split."""
        if self.train:
            return len(self.train_data)
        return len(self.test_data)

    def _check_exists(self):
        """True only if both the training and the test pickle exist under root."""
        return os.path.exists(os.path.join(self.root, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.test_file))

    def generate(self, force_generate=False):
        """Generate and pickle both splits, unless they already exist.

        Args:
            force_generate: regenerate even if both pickles are present.
        """
        if self._check_exists() and not force_generate:
            return
        # Ensure the destination directory exists before pickles are written into it.
        if not os.path.exists(self.root):
            os.makedirs(self.root)
        splits = ((True, N_noisy_train, K_train, self.training_file),
                  (False, N_noisy_test, K_test, self.test_file))
        for train, n_noisy, k, filename in splits:
            datas = generate_dataset(train=train, n_noisy=n_noisy, k=k)
            save_pkl(datas, os.path.join(self.root, filename))
            # Release this split's samples before the next split is generated
            # so that the two lists are never alive at the same time.
            del datas

    def save_image(self):
        """Export the loaded split as BMP files, replacing any previous export.

        Images go to ``<root>/noisy_train`` or ``<root>/noisy_test``, depending
        on ``self.train``.
        """
        if self.train:
            folder, data, labels = (self.training_images_root,
                                    self.train_data, self.train_labels)
        else:
            folder, data, labels = (self.test_images_root,
                                    self.test_data, self.test_labels)
        # Use self.root rather than config.data_dir so that a custom root
        # passed to __init__ is honored, as it is for the pickles.
        image_root = os.path.join(self.root, folder)
        if os.path.exists(image_root):
            shutil.rmtree(image_root)
        save_images(zip(data, labels), image_root)
if __name__ == '__main__':
    # Regenerate both splits from scratch. The returned object already holds
    # the training split, so it is not loaded a second time.
    dataset = MNISTNoisy(train=True, generate=True, force_generate=True)
    print(len(dataset))
    # dataset.save_image()  # uncomment to export the training split as BMPs
    dataset = MNISTNoisy(train=False)
    print(len(dataset))
    # dataset.save_image()  # uncomment to export the test split as BMPs