Skip to content

Commit

Permalink
Merge pull request #4 from tiagofrepereira2012/master
Browse files Browse the repository at this point in the history
Solved issue #2
  • Loading branch information
tiagofrepereira2012 committed Jul 17, 2015
2 parents a63b9b2 + d23dd3f commit d039e50
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
14 changes: 9 additions & 5 deletions bob/ip/base/cpp/DCTFeatures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,15 @@ void bob::ip::base::DCTFeatures::extract_(const blitz::Array<double,2>& src, bli
{
blitz::firstIndex ii;
blitz::secondIndex jj;
blitz::thirdIndex kk;
m_cache_dct1 = blitz::mean(dst(kk,ii,jj), kk); // mean
m_cache_dct2 = blitz::sum(blitz::pow2(dst(kk,ii,jj) - m_cache_dct1(ii,jj)),kk) / (double)(dst.extent(0));
m_cache_dct2 = blitz::where(m_cache_dct2 <= m_norm_epsilon, 1., blitz::sqrt(m_cache_dct2));
dst = (dst(ii,jj,kk) - m_cache_dct1(kk)) / m_cache_dct2(kk);

for (int kk=0; kk<dst.extent(2); ++kk) {
blitz::Array<double,2> dst_slice = dst(blitz::Range::all(), blitz::Range::all(), kk);
double mean = blitz::mean(dst_slice);
m_cache_dct1(kk) = mean;
double variance = blitz::sum(blitz::pow2(dst_slice))/(double)(dst.extent(0)*dst.extent(1)) - mean*mean;
m_cache_dct2(kk) = variance;
dst_slice = (dst_slice(ii,jj) - mean) / (variance <= m_norm_epsilon ? 1. : sqrt(variance));
}
}
}

Expand Down
16 changes: 16 additions & 0 deletions bob/ip/base/test/test_dct.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,19 @@ def test04_output_shape():
dct_op = bob.ip.base.DCTFeatures(6, (3, 4), (0, 0))
assert dct_op.output_shape(src, flat=True) == (4,6)
assert dct_op.output_shape(src, flat=False) == (2,2,6)


def test_3Doutput_and_2Doutput():
numpy.random.seed(3)
data = numpy.random.randn(10,10)

#Comparing the regular output with the output per block WITHOUT DCT coefficient normalization
o = bob.ip.base.DCTFeatures(5, (3,3),(1,1),True,False,False)
assert (numpy.allclose(o(data, False).flatten(),o(data, True).flatten(),1e-10))

#Comparing the regular output with the output per block WITH DCT coefficient normalization
o = bob.ip.base.DCTFeatures(5,(3,3),(1,1),True,True,False)
assert (numpy.allclose(o(data, False).flatten(),o(data, True).flatten(),1e-10))



0 comments on commit d039e50

Please sign in to comment.