/
make_cifar10_whitened.py
72 lines (52 loc) · 2.24 KB
/
make_cifar10_whitened.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
This script makes a dataset of 32x32 approximately whitened CIFAR-10 images.
"""
from __future__ import print_function
from pylearn2.utils import serial
from pylearn2.datasets import preprocessing
from pylearn2.utils import string_utils
import numpy as np
from pylearn2.datasets.cifar10 import CIFAR10
import textwrap
def main():
    """Build a ZCA-whitened copy of CIFAR-10 and save it to disk.

    The output goes to ``${PYLEARN2_DATA_PATH}/cifar10/pylearn2_whitened``
    and consists of:

    * ``README``           -- short description of the other files
    * ``train.pkl``/``train.npy`` -- the whitened train set
    * ``test.pkl``/``test.npy``   -- the whitened test set
    * ``preprocessor.pkl`` -- the ZCA object that was fit on the train set

    The preprocessor is fit on the train set only and then applied,
    without refitting, to the test set.
    """
    data_dir = string_utils.preprocess('${PYLEARN2_DATA_PATH}/cifar10')

    print('Loading CIFAR-10 train dataset...')
    train = CIFAR10(which_set='train')

    print("Preparing output directory...")
    output_dir = data_dir + '/pylearn2_whitened'
    serial.mkdir(output_dir)

    # The context manager closes the file even if the write fails.
    with open(output_dir + '/README', 'w') as readme:
        readme.write(textwrap.dedent("""
        The .pkl files in this directory may be opened in python using
        cPickle, pickle, or pylearn2.serial.load.
        train.pkl, and test.pkl each contain
        a pylearn2 Dataset object defining a labeled
        dataset of a 32x32 approximately whitened version of the CIFAR-10
        dataset. train.pkl contains labeled train examples. test.pkl
        contains labeled test examples.
        preprocessor.pkl contains a pylearn2 ZCA object that was used
        to approximately whiten the images. You may want to use this
        object later to preprocess other images.
        They were created with the pylearn2 script make_cifar10_whitened.py.
        All other files in this directory, including this README, were
        created by the same script and are necessary for the other files
        to function correctly.
        """))

    # Adjacent literals instead of a backslash continuation inside the
    # string, so source indentation never leaks into the printed message.
    print("Learning the preprocessor and preprocessing "
          "the train data...")
    preprocessor = preprocessing.ZCA()
    # can_fit=True: the whitening transform is estimated from the train set.
    train.apply_preprocessor(preprocessor=preprocessor, can_fit=True)

    print('Saving the train data')
    # NOTE(review): use_design_loc presumably makes the pickled dataset keep
    # its design matrix in the given .npy file -- confirm against pylearn2.
    train.use_design_loc(output_dir + '/train.npy')
    serial.save(output_dir + '/train.pkl', train)

    print("Loading the test data")
    test = CIFAR10(which_set='test')

    print("Preprocessing the test data")
    # can_fit=False: reuse the transform learned on train; never refit on test.
    test.apply_preprocessor(preprocessor=preprocessor, can_fit=False)

    print("Saving the test data")
    test.use_design_loc(output_dir + '/test.npy')
    serial.save(output_dir + '/test.pkl', test)

    # Keep the fitted preprocessor so other images can be whitened the same way.
    serial.save(output_dir + '/preprocessor.pkl', preprocessor)


# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()