Skip to content

Commit

Permalink
Merge pull request #2970 from rcurtin/julia-serialize-len
Browse files Browse the repository at this point in the history
Julia: serialize the length of the model.
  • Loading branch information
zoq committed Jun 9, 2021
2 parents 0259155 + d0e8351 commit 66eecdc
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
* Fixes to `HoeffdingTree`: ensure that training still works when empty
constructor is used (#2964).

* Fix Julia model serialization bug (#2970).

### mlpack 3.4.2
###### 2020-10-26
* Added Mean Absolute Percentage Error.
Expand Down
10 changes: 7 additions & 3 deletions src/mlpack/bindings/julia/print_param_defn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,14 @@ void PrintParamDefn(
// buffer = ccall((:Serialize<Type>Ptr, <programName>Library),
// Vector{UInt8}, (Ptr{Nothing}, Ptr{UInt8}), model.ptr,
// Base.pointer(buf_len))
// buf = Base.unsafe_wrap(buf_ptr, buf_len[0]; own=true)
// buf = Base.unsafe_wrap(buf_ptr, buf_len[1]; own=true)
// write(stream, buf_len[1])
// write(stream, buf)
// end
//
// function deserialize<Type>(stream::IO)::<Type>
// buffer = read(stream)
// buf_len = read(stream, UInt)
// buffer = read(stream, buf_len)
// <Type>(ccall((:Deserialize<Type>Ptr, <programName>Library),
// Ptr{Nothing}, (Vector{UInt8}, UInt), buffer, length(buffer)))
// end
Expand Down Expand Up @@ -138,14 +140,16 @@ void PrintParamDefn(
<< "Base.pointer(buf_len))" << std::endl;
std::cout << " buf = Base.unsafe_wrap(Vector{UInt8}, buf_ptr, buf_len[1]; "
<< "own=true)" << std::endl;
std::cout << " write(stream, buf_len[1])" << std::endl;
std::cout << " write(stream, buf)" << std::endl;
std::cout << "end" << std::endl;

// And the deserialization functionality.
std::cout << "# Deserialize a model from the given stream." << std::endl;
std::cout << "function deserialize" << type << "(stream::IO)::" << type
<< std::endl;
std::cout << " buffer = read(stream)" << std::endl;
std::cout << " buf_len = read(stream, UInt)" << std::endl;
std::cout << " buffer = read(stream, buf_len)" << std::endl;
std::cout << " " << type << "(ccall((:Deserialize" << type << "Ptr, "
<< programName << "Library), Ptr{Nothing}, (Ptr{UInt8}, UInt), "
<< "Base.pointer(buffer), length(buffer)))" << std::endl;
Expand Down
21 changes: 21 additions & 0 deletions src/mlpack/bindings/julia/tests/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,27 @@ end
model_in=newModel)
end

# Test that we can serialize a model as part of a larger tuple.
@testset "TestStreamTupleSerialization" begin
_, _, _, _, _, _, modelOut, _, _, _, _, _, _, _ =
test_julia_binding(4.0, 12, "hello",
build_model=true)

stream = IOBuffer()
serialize(stream, (modelOut, 3, 4, 5))

newStream = IOBuffer(copy(stream.data))
(newModel, a, b, c) = deserialize(newStream)

_, _, _, _, _, bwOut, _, _, _, _, _, _, _, _ =
test_julia_binding(4.0, 12, "hello",
model_in=newModel)

@test a == 3
@test b == 4
@test c == 5
end

@testset "TestFileSerialization" begin
_, _, _, _, _, _, modelOut, _, _, _, _, _, _, _ =
test_julia_binding(4.0, 12, "hello",
Expand Down

0 comments on commit 66eecdc

Please sign in to comment.