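"""Train a deconvolutional network on one STL-10 training split.

The split (0-9) is selected with -s/--split; the network is trained from a
random initialization and evaluated on the STL-10 test set during training.
"""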
import argparse
import os
import sys

sys.path.append('../../..')

import numpy

from anna import util
from anna.datasets import supervised_dataset
from model import SupervisedModel

print('Start')
parser = argparse.ArgumentParser(
    prog='train_finetune_random',
    description='Script to train deconvolutional network '
                'from random initialization.')
parser.add_argument('-s', '--split', default='0',
                    help='Training split of stl10 to use. (0-9)')
args = parser.parse_args()

train_split = int(args.split)
if train_split < 0 or train_split > 9:
    raise Exception('Training split must be in range 0-9.')
print('Using STL10 training split: {}'.format(train_split))
# Record the process id so this run can be identified (and killed) later.
pid = os.getpid()
print('PID: {}'.format(pid))
f = open('pid_' + str(train_split), 'wb')
f.write(str(pid) + '\n')
f.close()
model = SupervisedModel('experiment', './', learning_rate=1e-2)
monitor = util.Monitor(model,
                       checkpoint_directory='checkpoints_' + str(train_split))
# Load the STL-10 training split and the test set.
print('Loading Data')
X_train = numpy.load('/data/stl10_matlab/train_splits/train_X_'
                     + str(train_split) + '.npy')
y_train = numpy.load('/data/stl10_matlab/train_splits/train_y_'
                     + str(train_split) + '.npy')
X_test = numpy.load('/data/stl10_matlab/test_X.npy')
y_test = numpy.load('/data/stl10_matlab/test_y.npy')

# Convert images to float32 and rescale pixel values from [0, 255] to [0, 2].
X_train = numpy.float32(X_train)
X_train /= 255.0
X_train *= 2.0

X_test = numpy.float32(X_test)
X_test /= 255.0
X_test *= 2.0

train_dataset = supervised_dataset.SupervisedDataset(X_train, y_train)
test_dataset = supervised_dataset.SupervisedDataset(X_test, y_test)
# Iterators draw batches of 128 examples uniformly at random.
train_iterator = train_dataset.iterator(
    mode='random_uniform', batch_size=128, num_batches=45000)
test_iterator = test_dataset.iterator(
    mode='random_uniform', batch_size=128, num_batches=45000)

# Create object to local contrast normalize a batch.
# Note: Every batch must be normalized before use.
normer = util.Normer2(filter_size=5, num_channels=3)
# Data augmentation (applied to training batches only).
augmenter = util.DataAugmenter(16, (96, 96))
print('Training Model')
for x_batch, y_batch in train_iterator:
    # Move the batch dimension to the last axis, then augment and
    # contrast-normalize the batch.
    x_batch = x_batch.transpose(1, 2, 3, 0)
    x_batch = augmenter.run(x_batch)
    x_batch = normer.run(x_batch)
    # y_batch = numpy.int64(numpy.argmax(y_batch, axis=1))
    monitor.start()
    # STL-10 labels are 1-10; shift to 0-9 before training.
    log_prob, accuracy = model.train(x_batch, y_batch - 1)
    monitor.stop(1 - accuracy)  # monitor takes error instead of accuracy

    # When the monitor signals a test step, evaluate on a batch from the
    # test set.
    if monitor.test:
        monitor.start()
        x_test_batch, y_test_batch = test_iterator.next()
        x_test_batch = x_test_batch.transpose(1, 2, 3, 0)
        x_test_batch = normer.run(x_test_batch)
        # y_test_batch = numpy.int64(numpy.argmax(y_test_batch, axis=1))
        test_accuracy = model.eval(x_test_batch, y_test_batch - 1)
        monitor.stop_test(1 - test_accuracy)
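# Example invocation (assumes the hard-coded STL-10 .npy files above exist):
#   python train.py -s 3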