1:10: loss: 2.2915698289871216 accuracy: 0.1171875 duration: 88072 ms.
1:20: loss: 2.245325207710266 accuracy: 0.1796875 duration: 176792 ms.
1:30: loss: 2.2091324806213377 accuracy: 0.21875 duration: 263850 ms.
1:40: loss: 2.166573643684387 accuracy: 0.234375 duration: 350876 ms.
2018-05-01 11:14:19,273 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:10: loss: 2.287407875061035 accuracy: 0.171875 duration: 81427 ms.
2018-05-01 11:15:45,166 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:20: loss: 2.2399443626403808 accuracy: 0.1875 duration: 85893 ms.
2018-05-01 11:17:11,915 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:30: loss: 2.207755136489868 accuracy: 0.21875 duration: 86749 ms.
2018-05-01 11:18:38,611 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:40: loss: 2.1901093244552614 accuracy: 0.1796875 duration: 86696 ms.
batchSize = 128
par = 4
2018-05-01 11:33:04,512 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:10: loss: 2.293930506706238 accuracy: 0.140625 duration: 84542 ms.
2018-05-01 11:34:31,409 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:20: loss: 2.2760065317153932 accuracy: 0.1484375 duration: 86897 ms.
2018-05-01 11:35:59,488 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:30: loss: 2.246646547317505 accuracy: 0.1953125 duration: 88079 ms.
batchSize = 128
par = 8
2018-05-01 11:38:21,964 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:10: loss: 2.2916112661361696 accuracy: 0.1328125 duration: 83810 ms.
2018-05-01 11:39:50,174 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:20: loss: 2.270481753349304 accuracy: 0.1875 duration: 88210 ms.
2018-05-01 11:41:17,313 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:30: loss: 2.2637011051177978 accuracy: 0.171875 duration: 87139 ms.
batchSize = 256
par = 4
2018-05-01 11:44:56,798 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:10: loss: 2.2893950939178467 accuracy: 0.12109375 duration: 168209 ms.
2018-05-01 11:47:50,127 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:20: loss: 2.279911971092224 accuracy: 0.1328125 duration: 173329 ms.
batchSize = 256
par = 8
2018-05-01 11:52:20,983 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:10: loss: 2.275270128250122 accuracy: 0.15625 duration: 165857 ms.
2018-05-01 11:55:15,282 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:20: loss: 2.23012433052063 accuracy: 0.2109375 duration: 174299 ms.
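
Rough throughput check on the labelled runs above (assumption: the duration field is wall-clock time per 10 mini-batches, which matches the near-constant values within each block):

val imgPerSec128par4 = 128 * 10 / 86.0  // ~14.9 images/s (batchSize = 128, par = 4)
val imgPerSec128par8 = 128 * 10 / 87.0  // ~14.7 images/s (batchSize = 128, par = 8)
val imgPerSec256par4 = 256 * 10 / 170.0 // ~15.1 images/s (batchSize = 256, par = 4)
val imgPerSec256par8 = 256 * 10 / 170.0 // ~15.1 images/s (batchSize = 256, par = 8)
// Throughput hovers around 15 images/s in every configuration, so neither the larger
// batch size nor the extra parallelism changes the effective speed much in these runs.
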
package scorch.data.loader

import akka.actor.{Actor, ActorSystem, Props}
import akka.stream._
import akka.stream.scaladsl._
import akka.util.Timeout
import botkop.{numsca => ns}
import scorch._
import scorch.autograd.Variable
import scorch.data.loader.NetActor.{Ack, Complete, Init}
import scorch.nn.cnn.{Conv2d, MaxPool2d}
import scorch.nn.{Linear, Module}
import scorch.optim.Adam

import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.language.postfixOps
import scala.util.{Failure, Success}
// Actor that trains `net`: for every (x, y) mini-batch it receives it runs a forward
// pass, computes the loss, back-propagates, takes an optimizer step and acks the sender.
class NetActor(net: Module, lossFunction: (Variable, Variable) => Variable)
    extends Actor {

  val optimizer = Adam(net.parameters, lr = 0.01)

  override def receive: Receive = {
    case (x: Variable, y: Variable) =>
      println("received x y")
      val yHat = net(x)
      val loss = lossFunction(yHat, y)
      val cost = loss.data.squeeze()
      println(s"loss = $cost")
      net.zeroGrad()
      loss.backward()
      println(net.parameters.size)
      optimizer.step()
      println("back prop done")
      sender ! Ack
    case _: Init.type =>
      sender ! Ack
    case Complete =>
      println("done")
  }
}
/*
class OptimActor(optimizer: Optimizer) extends Actor {
  override def receive: Receive = {
    case gradients: Seq[Variable] =>
      optimizer.parameters.zip(gradients).foreach { case (p, g) =>
        p.grad.data := g.data
      }
      optimizer.step()
      sender ! optimizer.parameters
  }
}
*/
object NetActor {
  def props(net: Module) = Props(new NetActor(net, softmaxLoss))

  case class Backward(loss: Variable)
  case object Init
  case object Ack
  case object Complete
}
// Small CNN used in the experiments below: conv -> relu -> max-pool -> flatten -> fc -> relu.
case class Net(batchSize: Int) extends Module {
  val numChannels = 3
  val imageSize = 32
  val numClasses = 10
  val inputShape = List(batchSize, numChannels, imageSize, imageSize)

  val conv = Conv2d(numChannels = 3,
                    numFilters = 4,
                    filterSize = 5,
                    weightScale = 1e-3,
                    pad = 1,
                    stride = 1)
  val pool = MaxPool2d(poolSize = 2, stride = 2)

  // number of features after conv + pool, ignoring the batch dimension
  val numFlatFeatures: Int =
    pool.outputShape(conv.outputShape(inputShape)).tail.product

  def flatten(v: Variable): Variable = v.reshape(batchSize, numFlatFeatures)

  val fc = Linear(numFlatFeatures, numClasses)

  override def forward(x: Variable): Variable =
    x ~> conv ~> relu ~> pool ~> flatten ~> fc ~> relu
}
// Streams CIFAR-10 mini-batches into a single NetActor, using actorRefWithAck for backpressure.
object ParallelLoader extends App {
  val batchSize = 8
  val loader = new Cifar10DataLoader(miniBatchSize = batchSize,
                                     mode = "train",
                                     take = Some(80))

  implicit val system: ActorSystem = ActorSystem("scorch")
  implicit val materializer: ActorMaterializer = ActorMaterializer()
  implicit val askTimeout: Timeout = Timeout(60 seconds)

  val net = Net(batchSize)
  val netActor = system.actorOf(NetActor.props(net))

  val source = Source(loader)
  val sink = Sink.actorRefWithAck(netActor, Init, Ack, Complete)
  source.runWith(sink)
}
// Poor man's data parallelism: `parallelism` worker copies of the net each process one
// mini-batch, their gradients are averaged into the base net, the optimizer steps on the
// base net, and the updated parameters are copied back to the workers.
object FlatLoader extends App {

  import scala.concurrent.ExecutionContext.Implicits.global

  val parallelism = 4
  val batchSize = 16

  val loader = new Cifar10DataLoader(miniBatchSize = batchSize,
                                     mode = "train",
                                     take = None)

  val base = Net(batchSize)
  val workers = Seq.fill(parallelism)(Net(batchSize))
  val optimizer = Adam(base.parameters, lr = 0.001)

  def add(as: Seq[Variable], bs: Seq[Variable]): Seq[Variable] =
    as.zip(bs).map { case (a, b) => Variable(a.data + b.data) }

  // forward + backward for one mini-batch on one worker
  def pass(iteration: Int, worker: Net, x: Variable, y: Variable): Unit = {
    println("pass")
    val yHat = worker(x)
    val loss = softmaxLoss(yHat, y)
    println(s"iteration: $iteration loss = ${loss.data.squeeze()}")
    loss.backward()
  }

  // average the workers' gradients into the base net's gradients
  def updateBaseGradients(allGradients: Seq[Seq[Variable]]): Unit = {
    val sums = allGradients.fold(base.gradients) {
      case (a, b) => add(a, b)
    }
    val means = sums.map(_ / parallelism)
    base.gradients.zip(means).foreach {
      case (bg, m) =>
        bg.data := m.data
    }
  }

  loader.zipWithIndex
    .sliding(parallelism, parallelism)
    .map(_.toList)
    .foreach { pb: List[((Variable, Variable), Int)] =>
      base.zeroGrad()

      val fs: Seq[Future[Seq[Variable]]] = workers
        .zip(pb)
        .map {
          case (worker, ((x, y), ix)) =>
            Future {
              worker.zeroGrad()
              pass(ix, worker, x, y)
              worker.gradients
            }
        }

      val results = Future.sequence(fs)

      results.onComplete {
        case Success(allGradients) =>
          /*
          val sums = allGradients.fold(base.gradients) {
            case (a, b) => add(a, b)
          }
          val means = sums.map(_ / parallelism)
          base.gradients.zip(means).foreach {
            case (bg, m) =>
              bg := m
          }
          */
          updateBaseGradients(allGradients)
          optimizer.step()
          // push the updated base parameters back to the workers
          workers.foreach { w =>
            base.parameters.zip(w.parameters).foreach {
              case (bp, wp) =>
                wp.data := bp.data
            }
          }
          println("step")
        case Failure(ex) =>
          throw new Exception(ex)
      }
      Await.ready(results, 20 seconds)
    }
}
case class Parallelize(module: Module,
                       parallelism: Int,
                       timeOut: Duration = Duration.Inf)
    extends Module {
  // val workers: Seq[Module[Id]] = Seq.fill(parallelism)(module.clone().asInstanceOf[Module[Id]])
  override def forward(x: Variable): Variable = ???
}
object Parallelize {

  case class ParallelizeFunction(x: Variable,
                                 baseModule: Module,
                                 workerModules: Seq[Module],
                                 timeOut: Duration = Duration.Inf)
      extends scorch.autograd.Function {

    import ns._
    import scala.concurrent.ExecutionContext.Implicits.global

    val parallelism: Int = workerModules.length
    val batchSize: Int = x.shape.head
    val chunkSize: Int = batchSize / parallelism

    // split the batch into `parallelism` chunks of `chunkSize` rows, one per worker
    // (numsca ranges are exclusive at the upper end, hence `s.last + 1`)
    def scatter(v: Variable): Seq[Variable] =
      (0 until batchSize)
        .sliding(chunkSize, chunkSize)
        .map(s => (s.head, s.last + 1))
        .map {
          case (first, last) =>
            Variable(v.data(first :> last))
        }
        .toSeq
    val fs: Seq[Future[Variable]] = scatter(x).zip(workerModules).map {
      case (v, worker) =>
        Future {
          workerModules.foreach { wm =>
            // set parameters of all workers to parameters of base module
            // set gradients of all workers to gradients of base module
            wm.parameters.zip(baseModule.parameters).foreach {
              case (wp, bp) =>
                wp.data := bp.data
                wp.grad.data := bp.grad.data
            }
          }
          worker(v)
        }
    }

    val activations: Seq[Variable] = Await.result(Future.sequence(fs), timeOut)

    override def forward(): Variable = {
      Variable(ns.concatenate(activations.map(_.data)), Some(this))
    }

    override def backward(gradOutput: Variable): Unit = {
      val gs = scatter(gradOutput)
      val fs = activations.zip(gs).map {
        case (v, g) =>
          Future(v.backward(g))
      }
      Await.result(Future.sequence(fs), timeOut)

      // collect gradients from workers and accumulate in base module
      workerModules.foreach { wm =>
        baseModule.gradients.zip(wm.gradients).foreach {
          case (bg, wg) =>
            bg.data += wg.data
        }
      }
    }
  }
}
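
Rough usage sketch (assumption, not part of the scrap above): Parallelize.forward could delegate to ParallelizeFunction once per-worker copies of the module exist, and a caller would then wrap any Module:

// Hypothetical wiring, assuming `workers` holds per-worker copies of `module`:
//
//   override def forward(x: Variable): Variable =
//     Parallelize.ParallelizeFunction(x, module, workers, timeOut).forward()
//
// Hypothetical caller:
//
//   val net = Parallelize(Net(batchSize), parallelism = 4)
//   val yHat = net(x)                 // forward scatters x across the workers and concatenates the outputs
//   softmaxLoss(yHat, y).backward()   // backward accumulates the workers' gradients into the base module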