Skip to content

Commit

Permalink
test: dilated serialization test
Browse files Browse the repository at this point in the history
  • Loading branch information
i8run committed Nov 14, 2017
1 parent 6c8a9ee commit 71498f0
Showing 1 changed file with 32 additions and 0 deletions.
Expand Up @@ -17,6 +17,7 @@ package com.intel.analytics.bigdl.utils.serializer

import com.google.protobuf.{ByteString, CodedOutputStream}
import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.models.lenet.LeNet5
import com.intel.analytics.bigdl.nn.abstractnn.TensorModule
import com.intel.analytics.bigdl.nn.ops.ParseExample
import com.intel.analytics.bigdl.nn.{VolumetricFullConvolution, _}
Expand Down Expand Up @@ -1883,6 +1884,37 @@ class ModuleSerializerSpec extends FlatSpec with Matchers {
res1 should be (res2)
}

"bigquant.SpatialDilatedConvolution serializer" should "work properly " in {
val nInputPlane = 1
val nOutputPlane = 1
val kW = 2
val kH = 2
val dW = 1
val dH = 1
val padW = 0
val padH = 0

val kernelData = Array(
2.0f, 3f,
4f, 5f
)

val biasData = Array(0.0f)

val input = Tensor(1, 1, 3, 3).apply1(_ => Random.nextFloat())
val weight = Tensor(Storage(kernelData), 1, Array(nOutputPlane, nInputPlane, kH, kW))
val bias = Tensor(Storage(biasData), 1, Array(nOutputPlane))
val conv = quantized.SpatialDilatedConvolution[Float](nInputPlane, nOutputPlane,
kW, kH, dW, dH, padW, padH, initWeight = weight, initBias = bias)

val res1 = conv.forward(input)

ModulePersister.saveToFile("/tmp/bigquant.dilated.conv.bigdl", conv, true)
val loadedConv = ModuleLoader.loadFromFile("/tmp/bigquant.dilated.conv.bigdl")
val res2 = loadedConv.forward(input)
res1 should be (res2)
}

"bigquant.Linear serializer" should "work properly " in {
val outputSize = 2
val inputSize = 2
Expand Down

0 comments on commit 71498f0

Please sign in to comment.