-
Notifications
You must be signed in to change notification settings - Fork 3
/
mentor.py
154 lines (133 loc) · 3.67 KB
/
mentor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import theano
import time
from lasagne.updates import rmsprop
from theano import tensor as T
import numpy as np
import numpy.random as rand
from inputFormat import *
from network import network
import matplotlib.pyplot as plt
import cPickle
import argparse
import os
def save():
print "saving network..."
if args.save:
save_name = args.save
else:
save_name = "mentor_network.save"
if args.data:
f = file(args.data+"/"+save_name, 'wb')
else:
f = file(save_name, 'wb')
cPickle.dump(network, f, protocol=cPickle.HIGHEST_PROTOCOL)
f.close()
if args.data:
f = file(args.data+"/costs.save","wb")
cPickle.dump(costs, f, protocol=cPickle.HIGHEST_PROTOCOL)
f.close()
# Command-line options: three optional string values. argparse leaves an option
# as None when it is not supplied, and the rest of the script tests each one
# for truthiness before using it.
parser = argparse.ArgumentParser()
for _long_flag, _short_flag, _help_text in (
    ("--load", "-l", "Specify a file with a prebuilt network to load."),
    ("--save", "-s", "Specify a file to save trained network to."),
    ("--data", "-d", "Specify a directory to save/load data for this run."),
):
    parser.add_argument(_long_flag, _short_flag, type=str, help=_help_text)
args = parser.parse_args()
print "loading data... "
datafile = open("data/scoredPositionsFull.npz", 'r')
data = np.load(datafile)
positions = data['positions']
scores = data['scores']
if args.data:
if not os.path.exists(args.data):
os.makedirs(args.data)
costs = []
else:
if os.path.exists(args.data+"/costs.save"):
f = file(args.data+"/costs.save")
costs = cPickle.load(f)
f.close
else:
costs = []
else:
costs = []
datafile.close()
positions = positions.astype(theano.config.floatX)
scores = scores.astype(theano.config.floatX)
n_train = scores.shape[0]
positions_batch = T.tensor4('positions_batch')
y = T.tensor3('y') #target output score
numEpochs = 100
iteration = 0
batch_size = 64
numBatches = n_train/batch_size
#if load parameter is passed load a network from a file
if args.load:
print "loading model..."
f = file(args.load, 'rb')
network = cPickle.load(f)
if(network.batch_size):
batch_size = network.batch_size
f.close()
else:
print "building model..."
#use batchsize none now so that we can easily use same network for picking single moves and evaluating batches
network = network(batch_size=None)
cost = T.mean(T.sqr(network.output.reshape((batch_size, boardsize, boardsize)) - y))
alpha = 0.001
rho = 0.9
epsilon = 1e-6
updates = rmsprop(cost, network.params, alpha, rho, epsilon)
train_model = theano.function(
[positions_batch, y],
cost,
updates = updates,
givens={
network.input: positions_batch,
}
)
test_model = theano.function(
[positions_batch, y],
cost,
givens={
network.input: positions_batch,
}
)
evaluate_model = theano.function(
[positions_batch],
network.output,
givens={
network.input: positions_batch,
}
)
costs = []
print "Training model on mentor set..."
indices = range(n_train)
try:
for epoch in range(numEpochs):
print "epoch: ",epoch
np.random.shuffle(indices)
cost_sum = 0
for batch in range(numBatches):
t = time.clock()
p_batch = positions[indices[batch*batch_size:(batch+1)*batch_size]]
s_batch = scores[indices[batch*batch_size:(batch+1)*batch_size]]
cost=train_model(p_batch, s_batch)
run_time = time.clock()-t
cost_sum+=cost
iteration+=1
print "Cost: ",cost_sum/(batch+1), " Time per position: ", run_time/(batch_size)
costs.append(cost_sum/(batch+1))
plt.plot(costs)
plt.ylabel('cost')
plt.xlabel('epoch')
plt.draw()
plt.pause(0.001)
#save snapshot of network every epoch in case something goes wrong
save()
except KeyboardInterrupt:
#save snapshot of network if we interrupt so we can pickup again later
save()
exit(1)
print "done training!"
save()
cPickle.dump(network, f, protocol=cPickle.HIGHEST_PROTOCOL)
f.close()