Skip to content

Commit

Permalink
test: replace reldiff with isapprox (#321)
Browse files Browse the repository at this point in the history
* test/ndarray: replace `reldiff` with `isapprox`

* test/symbolic-node: replace `reldiff` with `isapprox`

* test/operator: replace `reldiff` with `isapprox`

* test/io: replace `reldiff` with `isapprox`

* test: remove `reldiff`
  • Loading branch information
iblislin authored and pluskid committed Nov 13, 2017
1 parent f8e1938 commit a19fc93
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 109 deletions.
6 changes: 0 additions & 6 deletions test/common.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
################################################################################
# Common models used in testing
################################################################################
function reldiff(a, b)
diff = sum(abs.(a .- b))
norm = sum(abs.(a))
return diff / (norm + 1e-10)
end

function rand_dims(max_ndim=6)
tuple(rand(1:10, rand(1:max_ndim))...)
end
Expand Down
2 changes: 1 addition & 1 deletion test/unittest/bind.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module TestBind
using MXNet
using Base.Test

using ..Main: rand_dims, reldiff
using ..Main: rand_dims

################################################################################
# Test Implementations
Expand Down
14 changes: 8 additions & 6 deletions test/unittest/io.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module TestIO

using MXNet
using Base.Test

using ..Main: rand_dims, reldiff
using ..Main: rand_dims

function test_mnist()
info("IO::MNIST")
Expand Down Expand Up @@ -64,7 +65,7 @@ function test_arrays_impl(data::Vector, label::Vector, provider::mx.ArrayDataPro
data_get = mx.get_data(provider, batch)

for (d_real, d_get) in zip(data_batch, data_get)
@test reldiff(d_real, copy(d_get)[[1:n for n in size(d_real)]...]) < 1e-6
@test d_real copy(d_get)[[1:n for n in size(d_real)]...]
@test mx.count_samples(provider, batch) == size(d_real)[end]
end
end
Expand Down Expand Up @@ -97,7 +98,7 @@ function test_arrays_shuffle()

sample_count = 15
batch_size = 4
data = rand(1, sample_count)
data = rand(mx.MX_float, 1, sample_count)
label = collect(1:sample_count)
provider = mx.ArrayDataProvider(data, :index => label, batch_size=batch_size, shuffle=true)

Expand All @@ -107,14 +108,15 @@ function test_arrays_shuffle()
for (idx, batch) in zip(idx_all, provider)
data_batch = mx.get(provider, batch, :data)
label_batch = mx.get(provider, batch, :index)
ns_batch = mx.count_samples(provider, batch)
data_got[idx:idx+ns_batch-1] = copy(data_batch)[1:ns_batch]
ns_batch = mx.count_samples(provider, batch)
data_got[idx:idx+ns_batch-1] = copy(data_batch)[1:ns_batch]
label_got[idx:idx+ns_batch-1] = copy(label_batch)[1:ns_batch]
end

@test label_got != label
@test sort(label_got) == label
@test reldiff(data_got, data[:,Int[label_got...]]) < 1e-6
@test size(data_got) == size(data[:, Int[label_got...]])
@test data_got data[:, Int[label_got...]]
end

@testset "IO Test" begin
Expand Down

0 comments on commit a19fc93

Please sign in to comment.