im2col
koen-dejonghe committed May 1, 2018
1 parent 87f0f94 commit d8d0700
Showing 7 changed files with 194 additions and 151 deletions.
36 changes: 36 additions & 0 deletions scrap.txt
@@ -1,3 +1,39 @@
1:10: loss: 2.2915698289871216 accuracy: 0.1171875 duration: 88072 ms.
1:20: loss: 2.245325207710266 accuracy: 0.1796875 duration: 176792 ms.
1:30: loss: 2.2091324806213377 accuracy: 0.21875 duration: 263850 ms.
1:40: loss: 2.166573643684387 accuracy: 0.234375 duration: 350876 ms.


2018-05-01 11:14:19,273 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:10: loss: 2.287407875061035 accuracy: 0.171875 duration: 81427 ms.
2018-05-01 11:15:45,166 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:20: loss: 2.2399443626403808 accuracy: 0.1875 duration: 85893 ms.
2018-05-01 11:17:11,915 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:30: loss: 2.207755136489868 accuracy: 0.21875 duration: 86749 ms.
2018-05-01 11:18:38,611 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:40: loss: 2.1901093244552614 accuracy: 0.1796875 duration: 86696 ms.

batchSize = 128
par = 4
2018-05-01 11:33:04,512 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:10: loss: 2.293930506706238 accuracy: 0.140625 duration: 84542 ms.
2018-05-01 11:34:31,409 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:20: loss: 2.2760065317153932 accuracy: 0.1484375 duration: 86897 ms.
2018-05-01 11:35:59,488 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:30: loss: 2.246646547317505 accuracy: 0.1953125 duration: 88079 ms.

batchSize = 128
par = 8
2018-05-01 11:38:21,964 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:10: loss: 2.2916112661361696 accuracy: 0.1328125 duration: 83810 ms.
2018-05-01 11:39:50,174 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:20: loss: 2.270481753349304 accuracy: 0.1875 duration: 88210 ms.
2018-05-01 11:41:17,313 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:30: loss: 2.2637011051177978 accuracy: 0.171875 duration: 87139 ms.

batchSize = 256
par = 4
2018-05-01 11:44:56,798 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:10: loss: 2.2893950939178467 accuracy: 0.12109375 duration: 168209 ms.
2018-05-01 11:47:50,127 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:20: loss: 2.279911971092224 accuracy: 0.1328125 duration: 173329 ms.

batchSize = 256
par = 8
2018-05-01 11:52:20,983 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:10: loss: 2.275270128250122 accuracy: 0.15625 duration: 165857 ms.
2018-05-01 11:55:15,282 [INFO] from scorch.sandbox.cnn.LeNet5$ in main - 1:20: loss: 2.23012433052063 accuracy: 0.2109375 duration: 174299 ms.




package scorch.data.loader

import akka.actor.{Actor, ActorSystem, Props}
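A rough read of the timings in scrap.txt above: the per-image cost is nearly constant across batch sizes and `par` settings, so the parallelism knob is not buying anything here. Back-of-the-envelope arithmetic on the logged numbers:

// ~86,000 ms per 10 batches of 128 images vs ~170,000 ms per 10 batches of 256
val msPerImage128 = 86000.0 / (10 * 128)  // ~67 ms per image
val msPerImage256 = 170000.0 / (10 * 256) // ~66 ms per image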
2 changes: 1 addition & 1 deletion src/main/resources/logback.xml
@@ -22,7 +22,7 @@
<appender-ref ref="STDOUT" />
</appender>

<logger name="scorch" level="DEBUG" />
<logger name="scorch" level="INFO" />
<logger name="scorch.TestUtil" level="DEBUG" />
<logger name="org.nd4j" level="INFO" />
<logger name="org.reflections" level="ERROR" />
5 changes: 4 additions & 1 deletion src/main/scala/scorch/data/loader/Cifar10DataLoader.scala
@@ -39,6 +39,9 @@ class Cifar10DataLoader(mode: String,
(numSamples / miniBatchSize) +
(if (numSamples % miniBatchSize == 0) 0 else 1)

logger.info(s"number of samples: $numSamples")
logger.info(s"number of batches: $numBatches")

override def iterator: Iterator[(Variable, Variable)] = {
val batches: Iterator[List[File]] = new Random(seed)
.shuffle(files)
@@ -49,7 +52,7 @@
val batchSize = sampleFiles.length

// todo: maybe use akka streams here
val yxs = sampleFiles.par map deserialize
val yxs = sampleFiles map deserialize

val xData = yxs flatMap (_.x) toArray
val yData = yxs map (_.y) toArray
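The numBatches expression in this loader (shown above) is ceiling division. A minimal equivalent sketch, with a hypothetical standalone helper:

// number of mini-batches needed to cover all samples (ceiling division)
def numBatches(numSamples: Int, miniBatchSize: Int): Int =
  (numSamples + miniBatchSize - 1) / miniBatchSize

numBatches(50000, 128) // 391: the final batch carries the 80 leftover samples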
208 changes: 99 additions & 109 deletions src/main/scala/scorch/nn/cnn/Conv2d.scala
@@ -1,14 +1,14 @@
package scorch.nn.cnn

import botkop.{numsca => ns}
import botkop.numsca._
import botkop.{numsca => ns}
import com.typesafe.scalalogging.LazyLogging
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.convolution.Convolution
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.factory.Nd4j.PadMode
import scorch.nn.Module
import scorch.autograd.{Function, Variable}
import scorch.nn.Module

import scala.collection.immutable
import scala.collection.parallel.immutable.ParSeq

case class Conv2d(w: Variable, b: Variable, pad: Int, stride: Int)
extends Module(Seq(w, b)) {
@@ -20,7 +20,6 @@ case class Conv2d(w: Variable, b: Variable, pad: Int, stride: Int)

override def forward(x: Variable): Variable =
Im2colConv2dFunction(x, w, b, pad, stride).forward()
// NaiveConv2dFunction(x, w, b, pad, stride).forward()
}

object Conv2d extends LazyLogging {
@@ -50,58 +49,6 @@ object Conv2d extends LazyLogging {
List(numSamples, numFilters, hPrime, wPrime)
}
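outputShape implements the standard convolution size formula, hPrime = (height + 2 * pad - kernelHeight) / stride + 1, and likewise for the width; the same expressions appear in Im2colConv2dFunction below. A worked example with assumed CIFAR-10-style sizes:

// assumed sizes: 32x32 input, 3x3 kernel, pad 1, stride 1
val hPrime = (32 + 2 * 1 - 3) / 1 + 1 // 32: pad 1 preserves the spatial size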

/*
case class Im2colConv2dFunction(x: Variable,
w: Variable,
b: Variable,
pad: Int,
stride: Int)
extends Function {
val List(batchSize, numFilters, hPrime, wPrime) =
outputShape(x.shape, w.shape, pad, stride)
val List(kernelHeight, kernelWidth) = w.shape.takeRight(2)
override def forward(): Variable = {
val xCols: Tensor = new Tensor(
Convolution.im2col(x.data.array,
kernelHeight,
kernelWidth,
stride,
stride,
pad,
pad,
false))
println(x.shape)
println(x)
println()
println(xCols.shape.toList)
println(xCols)
println("!!!!!!!!!!!!!!!!!")
// println(xCols.transpose(1, 2, 3, 4, 5, 0).shape.toList)
// println(w.shape.toList)
val ws = w.data.reshape(w.shape.head, -1)
// val xt = xCols.transpose(3, 4, 5, 0, 1, 2).reshape(ws.shape.last, -1)
val xt = xCols.reshape(ws.shape.last, -1)
val res = ws.dot(xt) + b.data.reshape(-1, 1)
val out = res
.reshape(w.shape.head, hPrime, wPrime, x.shape.head)
.transpose(3, 0, 1, 2)
Variable(out, Some(this))
}
override def backward(gradOutput: Variable): Unit = ???
}
*/

case class NaiveConv2dFunction(x: Variable,
w: Variable,
b: Variable,
@@ -157,32 +104,6 @@ object Conv2d extends LazyLogging {
val dw = zerosLike(w.data)
val db = zerosLike(b.data)

/*
for {
n <- 0 until numDataPoints
dxPad = ns.pad(dx(n), padArea, PadMode.CONSTANT)
xPad = ns.pad(x.data(n), padArea, PadMode.CONSTANT)
f <- 0 until numFilters
wf = w.data(f)
hp <- 0 until hPrime
h1 = hp * stride
h2 = h1 + hh
wp <- 0 until wPrime
w1 = wp * stride
w2 = w1 + ww
} {
val d = dOut(n, f, hp, wp)
dxPad(:>, h1 :> h2, w1 :> w2) += wf * d
dw(f) += xPad(:>, h1 :> h2, w1 :> w2) * d
db(f) += d
dx(n) := dxPad(:>, 1 :> -1, 1 :> -1)
}
*/

(0 until numDataPoints).foreach { n =>
val dxPad = ns.pad(dx(n), padArea, PadMode.CONSTANT)
val xPad = ns.pad(x.data(n), padArea, PadMode.CONSTANT)
@@ -201,11 +122,10 @@ object Conv2d extends LazyLogging {
dxPad(:>, h1 :> h2, w1 :> w2) += wf * d
dw(f) += xPad(:>, h1 :> h2, w1 :> w2) * d
db(f) += d

dx(n) := dxPad(:>, 1 :> -1, 1 :> -1)
}
}
}
dx(n) := dxPad(:>, 1 :> -1, 1 :> -1)
}

x.backward(Variable(dx))
Expand All @@ -229,23 +149,18 @@ object Conv2d extends LazyLogging {

val Array(c, h, w) = x.shape
val newH = (h - hh) / stride + 1
val newW = w - ww / stride + 1

val newW = (w - ww) / stride + 1

val col = ns.zeros(newH * newW, c * hh * ww)
println(col.shape.toList)
println(x.shape.toList)

for {
i <- 0 until newH
j <- 0 until newW
} {
val patch =
x(:>, (i * stride) :> (i * stride + hh), (j * stride) :> (j * stride + ww))


println(col(i * newW + j).shape.toList)
println(patch.shape.toList)
x(:>,
(i * stride) :> (i * stride + hh),
(j * stride) :> (j * stride + ww))

col(i * newW + j) := patch.reshape(1, -1)
}
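im2col unrolls each kernel-sized patch of the image into one row of a matrix, so the whole convolution becomes a single matrix multiply against the flattened filters. A small shape sketch under assumed toy sizes:

// assumed sizes: 3 channels, 32x32 image, 3x3 kernel, stride 1 (padding is applied by the caller)
val (c, h, w, hh, ww, stride) = (3, 32, 32, 3, 3, 1)
val newH = (h - hh) / stride + 1 // 30
val newW = (w - ww) / stride + 1 // 30
// col shape: (newH * newW, c * hh * ww) = (900, 27),
// one row of 27 unrolled pixel values per output position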
@@ -284,35 +199,110 @@
}
}

/**
* @param dimCol gradients for imCol, shape (hPrime * wPrime, hh * ww * c)
* @param hPrime height of the feature map
* @param wPrime width of the feature map
* @param stride stride
* @param hh filter height
* @param ww filter width
* @param c number of channels
* @return gradients for x, shape (c, h, w)
*/
def col2imBack(dimCol: Tensor,
hPrime: Int,
wPrime: Int,
stride: Int,
hh: Int,
ww: Int,
c: Int): Tensor = {

val h = (hPrime - 1) * stride + hh
val w = (wPrime - 1) * stride + ww
val dx = ns.zeros(c, h, w)

for (i <- 0 until hPrime * wPrime) {
val row = dimCol(i)
val hStart = (i / wPrime) * stride
val wStart = (i % wPrime) * stride
dx(:>, hStart :> hStart + hh, wStart :> wStart + ww) +=
ns.reshape(row, c, hh, ww)
}
dx
}
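col2imBack inverts the im2col unrolling for the backward pass: each gradient row is reshaped back to a (c, hh, ww) patch and accumulated with += into the image, since overlapping receptive fields all contribute gradient to the shared input pixels. The output height is the forward formula solved for h:

// assumed sizes: hPrime = 30, stride = 1, filter height hh = 3
val h = (30 - 1) * 1 + 3 // 32: recovers the (padded) input height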

case class Im2colConv2dFunction(x: Variable,
w: Variable,
b: Variable,
pad: Int,
stride: Int)
extends Function {

override def forward(): Variable = {
val List(batchSize, c, height, width) = x.shape
val List(f, _, hh, ww) = w.shape
private val initStart = System.currentTimeMillis()

val hPrime = (height + 2 * pad - hh) / stride + 1
val wPrime = (width + 2 * pad - ww) / stride + 1
val List(batchSize, c, height, width) = x.shape
val List(f, _, hh, ww) = w.shape

val padArea = Array(Array(0, 0), Array(pad, pad), Array(pad, pad))
val hPrime: Int = (height + 2 * pad - hh) / stride + 1
val wPrime: Int = (width + 2 * pad - ww) / stride + 1

val out = ns.zeros(batchSize, f, hPrime, wPrime)
val padArea = Array(Array(0, 0), Array(pad, pad), Array(pad, pad))

val imCols: immutable.Seq[(Int, Tensor)] = for {
imNum <- 0 until batchSize
im = x.data(imNum)
imPad = ns.pad(im, padArea, PadMode.CONSTANT)
col = im2col(imPad, hh, ww, stride)
} yield {
(imNum, col)
}

val filterCol: Tensor = w.data.reshape(f, -1).T

private val initEnd = System.currentTimeMillis()
logger.debug(s"initialization took ${initEnd - initStart} ms.")

for (imNum <- 0 until batchSize) {
val im = x.data(imNum)
val imPad = ns.pad(im, padArea, PadMode.CONSTANT)
val imCol = im2col(imPad, hh, ww, stride)
val filterCol = w.data.reshape(f, -1).T
val mul = imCol.dot(filterCol.T) + b.data
override def forward(): Variable = {
val fwdStart = System.currentTimeMillis()
val out = ns.zeros(batchSize, f, hPrime, wPrime)
for ((imNum, imCol) <- imCols) {
val mul = imCol.dot(filterCol) + b.data
out(imNum) := col2im(mul, hPrime, wPrime, 1)
}
val fwdEnd = System.currentTimeMillis()
logger.debug(s"forward pass took ${fwdEnd - fwdStart} ms.")
Variable(out, Some(this))
}
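As a shape check on this forward pass (sizes assumed): per image, imCol is (hPrime * wPrime, c * hh * ww) and filterCol is (c * hh * ww, f), so imCol.dot(filterCol) + b.data has shape (hPrime * wPrime, f), which col2im then rearranges into the (f, hPrime, wPrime) feature map.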

override def backward(gradOutput: Variable): Unit = ???
override def backward(gradOutput: Variable): Unit = {

val bwStart = System.currentTimeMillis()

val dOut = gradOutput.data

val dx = zerosLike(x.data)
val dw = zerosLike(w.data)
val db = zerosLike(b.data)

for ((i, imCol) <- imCols) {
val dMul = ns.reshape(dOut(i), f, -1).T
db += ns.sum(dMul, axis = 0)

val dFilterCol = imCol.T.dot(dMul)
val dimCol = dMul.dot(filterCol.T)

val dxPadded = col2imBack(dimCol, hPrime, wPrime, stride, hh, ww, c)
dx(i) := dxPadded(:>, pad :> height + pad, pad :> width + pad)

dw += ns.reshape(dFilterCol.T, f, c, hh, ww)
}

x.backward(Variable(dx))
w.backward(Variable(dw))
b.backward(Variable(db))

val bwEnd = System.currentTimeMillis()
logger.debug(s"backward pass took ${bwEnd - bwStart} ms.")
}
}
}
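Putting the pieces together, a minimal usage sketch of this module (all shapes assumed; the constructor is the case class defined at the top of the file, and the bias shape is a guess at what broadcasts over the im2col rows):

import botkop.{numsca => ns}
import scorch.autograd.Variable
import scorch.nn.cnn.Conv2d

// assumed toy shapes: batch of 16 RGB 32x32 images, six 5x5 filters
val x = Variable(ns.randn(16, 3, 32, 32))
val weights = Variable(ns.randn(6, 3, 5, 5))
val bias = Variable(ns.randn(1, 6)) // assumed shape, broadcast over im2col rows
val conv = Conv2d(weights, bias, pad = 2, stride = 1)
val out = conv.forward(x) // (16, 6, 32, 32), via Im2colConv2dFunction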
17 changes: 8 additions & 9 deletions src/main/scala/scorch/nn/cnn/MaxPool2d.scala
@@ -26,10 +26,10 @@ object MaxPool2d {
poolHeight: Int,
poolWidth: Int,
stride: Int): List[Int] = {
val List(numDataPoints, numChannels, height, width) = inputShape
val hPrime: Int = 1 + (height - poolHeight) / stride
val wPrime: Int = 1 + (width - poolWidth) / stride
List(numDataPoints, numChannels, hPrime, wPrime)
val List(n, c, h, w) = inputShape
val hPrime = 1 + (h - poolHeight) / stride
val wPrime = 1 + (w - poolWidth) / stride
List(n, c, hPrime, wPrime)
}

case class NaiveMaxPool2dFunction(x: Variable,
@@ -38,16 +38,15 @@
stride: Int)
extends Function {

val List(numDataPoints, numChannels, hPrime, wPrime) =
val List(batchSize, numChannels, hPrime, wPrime) =
outputShape(x.shape, poolHeight, poolWidth, stride)

override def forward(): Variable = {

val out = ns.zeros(numDataPoints, numChannels, hPrime, wPrime)
val out = ns.zeros(batchSize, numChannels, hPrime, wPrime)

for {
n <- 0 until numDataPoints
c <- 0 until numChannels
n <- 0 until batchSize

h <- 0 until hPrime
h1 = h * stride
@@ -71,7 +70,7 @@
val dOut = gradOutput.data

for {
n <- 0 until numDataPoints
n <- 0 until batchSize
c <- 0 until numChannels

h <- 0 until hPrime
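Pooling uses the same size arithmetic as convolution, just without padding: hPrime = 1 + (h - poolHeight) / stride. For the common 2x2 pool with stride 2:

// assumed sizes: 32x32 feature map, 2x2 pool, stride 2
val hPrime = 1 + (32 - 2) / 2 // 16: each pooling layer halves the spatial size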
(2 more changed files not shown)
