Skip to content

Commit

Permalink
model: refine and test cases for FeedForward constructor (#346)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored and pluskid committed Nov 27, 2017
1 parent 935eb35 commit 2a5a284
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
9 changes: 2 additions & 7 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mutable struct FeedForward <: AbstractModel

# leave the rest fields undefined
FeedForward(arch :: SymbolicNode, ctx :: Vector{Context}) = new(arch, ctx)
FeedForward(arch :: SymbolicNode, ctx :: Context) = new(arch, [ctx])
end

"""
Expand Down Expand Up @@ -53,14 +54,8 @@ end
or a list of `Context` objects. In the latter case, data parallelization will be used
for training. If no context is provided, the default context `cpu()` will be used.
"""
function FeedForward(arch :: SymbolicNode; context :: Union{Context, Vector{Context}, Void} = nothing)
if isa(context, Void)
context = [Context(CPU)]
elseif isa(context, Context)
context = [context]
end
FeedForward(arch::SymbolicNode; context::Union{Context,Vector{Context}} = [cpu()]) =
FeedForward(arch, context)
end

"""
init_model(self, initializer; overwrite=false, input_shapes...)
Expand Down
34 changes: 34 additions & 0 deletions test/unittest/model.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module TestModel

using Base.Test
using MXNet


function test_feedforward()
info("Model::FeedForward::constructor")
let x = @mx.var x
m = mx.FeedForward(x)
@assert m.arch === x
@assert length(m.ctx) == 1
end

info("Model::FeedForward::constructor::keyword context")
let x = @mx.var x
m = mx.FeedForward(x, context = mx.cpu())
@assert m.arch === x
@assert length(m.ctx) == 1
end

let x = @mx.var x
m = mx.FeedForward(x, context = [mx.cpu(), mx.cpu(1)])
@assert m.arch === x
@assert length(m.ctx) == 2
end
end


@testset "Model Test" begin
test_feedforward()
end

end # module TestModel

0 comments on commit 2a5a284

Please sign in to comment.