Skip to content

Commit

Permalink
provide an optional context in mx.load_checkpoint (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy authored and pluskid committed Mar 29, 2017
1 parent 1781290 commit 9474d95
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,14 @@ function load_checkpoint(prefix :: AbstractString, epoch :: Int)
return (arch, arg_params, aux_params)
end

function load_checkpoint(prefix :: AbstractString, epoch :: Int, ::Type{FeedForward})
"""
load_checkpoint(prefix, epoch, ::mx.FeedForward; context)
Load a mx.FeedForward model from the checkpoint *prefix*, *epoch* and optionally provide a context.
"""
function load_checkpoint(prefix :: AbstractString, epoch :: Int, ::Type{FeedForward}; context = nothing)
arch, arg_params, aux_params = load_checkpoint(prefix, epoch)
model = FeedForward(arch)
model = FeedForward(arch, context = context)
model.arg_params = arg_params
model.aux_params = aux_params
return model
Expand Down

0 comments on commit 9474d95

Please sign in to comment.