From b38b688edfc4c2ee8bf461240f8384a70575d3a4 Mon Sep 17 00:00:00 2001 From: nilayjain Date: Tue, 14 Jun 2016 17:47:57 +0000 Subject: [PATCH 01/14] edge_boxes: feature extraction updated, tests added --- src/mlpack/methods/CMakeLists.txt | 1 + src/mlpack/methods/edge_boxes/CMakeLists.txt | 21 + .../methods/edge_boxes/edge_boxes_main.cpp | 91 ++ .../methods/edge_boxes/feature_extraction.hpp | 87 ++ .../edge_boxes/feature_extraction_impl.hpp | 960 ++++++++++++++++++ src/mlpack/tests/CMakeLists.txt | 1 + src/mlpack/tests/edge_boxes_test.cpp | 187 ++++ 7 files changed, 1348 insertions(+) create mode 100644 src/mlpack/methods/edge_boxes/CMakeLists.txt create mode 100644 src/mlpack/methods/edge_boxes/edge_boxes_main.cpp create mode 100644 src/mlpack/methods/edge_boxes/feature_extraction.hpp create mode 100644 src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp create mode 100644 src/mlpack/tests/edge_boxes_test.cpp diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt index 5734d5c9d8a..adb67489b67 100644 --- a/src/mlpack/methods/CMakeLists.txt +++ b/src/mlpack/methods/CMakeLists.txt @@ -23,6 +23,7 @@ set(DIRS decision_stump det emst + edge_boxes fastmks gmm hmm diff --git a/src/mlpack/methods/edge_boxes/CMakeLists.txt b/src/mlpack/methods/edge_boxes/CMakeLists.txt new file mode 100644 index 00000000000..ce7bdea79ab --- /dev/null +++ b/src/mlpack/methods/edge_boxes/CMakeLists.txt @@ -0,0 +1,21 @@ + +cmake_minimum_required(VERSION 2.8) + +# Define the files we need to compile. +# Anything not in this list will not be compiled into mlpack. +set(SOURCES + feature_extraction.hpp + feature_extraction_impl.hpp +) + +# Add directory name to sources. +set(DIR_SRCS) +foreach(file ${SOURCES}) + set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) +endforeach() +# Append sources (with directory name) to list of all mlpack sources (used at +# the parent scope). 
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) + +add_cli_executable(edge_boxes) + diff --git a/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp b/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp new file mode 100644 index 00000000000..b8bb0a79a97 --- /dev/null +++ b/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp @@ -0,0 +1,91 @@ +/** + * @file edge_boxes_main.cpp + * @author + * + * Driver program that loads sample data and runs StructuredForests data preparation. + */ +#include +#include "feature_extraction.hpp" + +using namespace mlpack; +using namespace mlpack::structured_tree; +using namespace std; + +int main() +{ + /* + :param options: + num_images: number of images in the dataset. + rgbd: 0 for RGB, 1 for RGB + depth + shrink: amount to shrink channels + n_orient: number of orientations per gradient scale + grd_smooth_rad: radius for image gradient smoothing + grd_norm_rad: radius for gradient normalization + reg_smooth_rad: radius for reg channel smoothing + ss_smooth_rad: radius for sim channel smoothing + p_size: size of image patches + g_size: size of ground truth patches + n_cell: number of self similarity cells + + n_pos: number of positive patches per tree + n_neg: number of negative patches per tree + fraction: fraction of features to use to train each tree + n_tree: number of trees in forest to train + n_class: number of classes (clusters) for binary splits + min_count: minimum number of data points to allow split + min_child: minimum number of data points allowed at child nodes + max_depth: maximum depth of tree + split: options include 'gini', 'entropy' and 'twoing' + discretize: optional function mapping structured to class labels + + stride: stride at which to compute edges + sharpen: sharpening amount (can only decrease after training) + n_tree_eval: number of trees to evaluate per location + nms: if true apply non-maximum suppression to edges + */ + + map options; + options["num_images"] = 2; + options["row_size"] = 321; + options["col_size"] = 481; + options["rgbd"] = 0; + 
options["shrink"] = 2; + options["n_orient"] = 4; + options["grd_smooth_rad"] = 0; + options["grd_norm_rad"] = 4; + options["reg_smooth_rad"] = 2; + options["ss_smooth_rad"] = 8; + options["p_size"] = 32; + options["g_size"] = 16; + options["n_cell"] = 5; + + options["n_pos"] = 10000; + options["n_neg"] = 10000; + //options["fraction"] = 0.25; + options["n_tree"] = 8; + options["n_class"] = 2; + options["min_count"] = 1; + options["min_child"] = 8; + options["max_depth"] = 64; + options["split"] = 0; // we use 0 for gini, 1 for entropy, 2 for other + options["stride"] = 2; + options["sharpen"] = 2; + options["n_tree_eval"] = 4; + options["nms"] = 1; // 1 for true, 0 for false + + StructuredForests SF(options); +// arma::uvec x(2); + //SF.GetFeatureDimension(x); + + arma::mat segmentations, boundaries, images; + data::Load("/home/nilay/Desktop/GSoC/code/example/example/small_images.csv", images); + data::Load("/home/nilay/Desktop/GSoC/code/example/example/small_boundary_1.csv", boundaries); + data::Load("/home/nilay/Desktop/GSoC/code/example/example/small_segmentation_1.csv", segmentations); + + arma::mat input_data = SF.LoadData(images, boundaries, segmentations); + cout << input_data.n_rows << " " << input_data.n_cols << endl; + SF.PrepareData(input_data); + cout << "PrepareData done." << endl; + return 0; +} + diff --git a/src/mlpack/methods/edge_boxes/feature_extraction.hpp b/src/mlpack/methods/edge_boxes/feature_extraction.hpp new file mode 100644 index 00000000000..68cae5f4acb --- /dev/null +++ b/src/mlpack/methods/edge_boxes/feature_extraction.hpp @@ -0,0 +1,87 @@ +/** + * @file feature_extraction.hpp + * @author Nilay Jain + * + * Feature Extraction for the edge_boxes algorithm. 
+ */ +#ifndef MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_HPP +#define MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_HPP +//#define INF 999999.9999 +//#define EPS 1E-20 +#include + +namespace mlpack { +namespace structured_tree { + +template +class StructuredForests +{ + + public: + + static constexpr double eps = 1e-20; + + std::map options; + + StructuredForests(const std::map inMap); + + MatType LoadData(MatType const &images, MatType const &boundaries,\ + MatType const &segmentations); + + void PrepareData(MatType const &InputData); + + arma::vec GetFeatureDimension(); + + arma::vec DistanceTransform1D(arma::vec const &f, const size_t n,\ + const double inf); + + void DistanceTransform2D(MatType &im, const double inf); + + MatType DistanceTransformImage(MatType const &im, double on); + + arma::field GetFeatures(MatType const &image, arma::umat &loc); + + CubeType CopyMakeBorder(CubeType const &InImage, size_t top, + size_t left, size_t bottom, size_t right); + + void GetShrunkChannels(CubeType const &InImage, CubeType ®_ch, CubeType &ss_ch); + + CubeType RGB2LUV(CubeType const &InImage); + + MatType bilinearInterpolation(MatType const &src, + size_t height, size_t width); + + CubeType sepFilter2D(CubeType &InOutImage, arma::vec &kernel,\ + size_t radius); + + CubeType ConvTriangle(CubeType &InImage, size_t radius); + + void Gradient(CubeType const &InImage, + MatType &Magnitude, + MatType &Orientation); + + MatType MaxAndLoc(CubeType &mag, arma::umat &Location) const; + + CubeType Histogram(MatType const &Magnitude, + MatType const &Orientation, + size_t downscale, size_t interp); + + CubeType ViewAsWindows(CubeType const &channels, arma::umat const &loc); + + CubeType GetRegFtr(CubeType const &channels, arma::umat const &loc); + + CubeType GetSSFtr(CubeType const &channels, arma::umat const &loc); + + CubeType Rearrange(CubeType const &channels); + + CubeType PDist(CubeType const &features, arma::uvec const &grid_pos); + + //void Discretize(MatType const &lbl, 
size_t n_class, size_t n_sample); +}; + + +} //namespace structured_tree +} // namespace mlpack +#include "feature_extraction_impl.hpp" +#endif + diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp new file mode 100644 index 00000000000..c7d7b87a980 --- /dev/null +++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ -0,0 +1,960 @@ +/** + * @file feature_extraction_impl.hpp + * @author Nilay Jain + * + * Implementation of feature extraction methods. + */ +#ifndef MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_IMPL_HPP +#define MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_IMPL_HPP + + +#include "feature_extraction.hpp" +#include + +namespace mlpack { +namespace structured_tree { + +template +StructuredForests:: +StructuredForests(const std::map inMap) + : options(std::move(inMap)) +{ + +} + +template +MatType StructuredForests:: +LoadData(MatType const &images, MatType const &boundaries,\ + MatType const &segmentations) +{ + const size_t num_images = this->options["num_images"]; + const size_t row_size = this->options["row_size"]; + const size_t col_size = this->options["col_size"]; + MatType input_data(num_images * row_size * 5, col_size); + // we store the input data as follows: + // images (3), boundaries (1), segmentations (1). 
+ size_t loop_iter = num_images * 5; + size_t row_idx = 0; + size_t col_i = 0, col_s = 0, col_b = 0; + for(size_t i = 0; i < loop_iter; ++i) + { + if (i % 5 == 4) + { + input_data.submat(row_idx, 0, row_idx + row_size - 1,\ + col_size - 1) = MatType(segmentations.colptr(col_s),\ + col_size, row_size).t(); + ++col_s; + } + else if (i % 5 == 3) + { + input_data.submat(row_idx, 0, row_idx + row_size - 1,\ + col_size - 1) = MatType(boundaries.colptr(col_b),\ + col_size, row_size).t(); + ++col_b; + } + else + { + input_data.submat(row_idx, 0, row_idx + row_size - 1,\ + col_size - 1) = MatType(images.colptr(col_i), + col_size, row_size).t(); + ++col_i; + } + row_idx += row_size; + } + return input_data; +} + +template +arma::vec StructuredForests:: +GetFeatureDimension() +{ + /* + shrink: amount to shrink channels + p_size: size of image patches + n_cell: number of self similarity cells + n_orient: number of orientations per gradient scale + */ + arma::vec FtrDim(2); + + const size_t shrink = this->options["shrink"]; + const size_t p_size = this->options["p_size"]; + const size_t n_cell = this->options["n_cell"]; + const size_t rgbd = this->options["rgbd"]; + const size_t n_orient = this->options["n_orient"]; + /* + n_color_ch: number of color channels + n_grad_ch: number of gradient channels + n_ch: total number of channels + */ + size_t n_color_ch; + if (this->options["rgbd"] == 0) + n_color_ch = 3; + else + n_color_ch = 4; + + const size_t n_grad_ch = 2 * (1 + n_orient); + + const size_t n_ch = n_color_ch + n_grad_ch; + FtrDim[0] = std::pow((p_size / shrink) , 2) * n_ch; + FtrDim[1] = std::pow(n_cell , 2) * (std::pow (n_cell, 2) - 1) / 2 * n_ch; + return FtrDim; +} + +template +arma::vec StructuredForests:: +DistanceTransform1D(arma::vec const &f, const size_t n, const double inf) +{ + arma::vec d(n), v(n), z(n + 1); + size_t k = 0; + v[0] = 0.0; + z[0] = -inf; + z[1] = inf; + for (size_t q = 1; q <= n - 1; ++q) + { + float s = ( (f[q] + q * q)-( f[v[k]] + v[k] * 
v[k]) ) / (2 * q - 2 * v[k]); + while (s <= z[k]) + { + --k; + s = ( (f[q] + q * q) - (f[v[k]] + v[k] * v[k]) ) / (2 * q - 2 * v[k]); + } + + ++k; + v[k] = static_cast(q); + z[k] = s; + z[k+1] = inf; + } + + k = 0; + for (size_t q = 0; q <= n-1; q++) + { + while (z[k+1] < q) + ++k; + d[q] = (q - v[k]) * (q - v[k]) + f[v[k]]; + } + return d; +} + +template +void StructuredForests:: +DistanceTransform2D(MatType &im, const double inf) +{ + arma::vec f(std::max(im.n_rows, im.n_cols)); + // transform along columns + for (size_t x = 0; x < im.n_cols; ++x) + { + f.subvec(0, im.n_rows - 1) = im.col(x); + arma::vec d = this->DistanceTransform1D(f, im.n_rows, inf); + im.col(x) = d; + } + + // transform along rows + for (size_t y = 0; y < im.n_rows; y++) + { + f.subvec(0, im.n_cols - 1) = im.row(y).t(); + arma::vec d = this->DistanceTransform1D(f, im.n_cols, inf); + im.row(y) = d.t(); + } +} + +/* euclidean distance transform of binary image using squared distance */ +template +MatType StructuredForests:: +DistanceTransformImage(MatType const &im, double on) +{ + //need a large value but not infinity. 
+ double inf = 999999.99; + MatType out = MatType(im.n_rows, im.n_cols, arma::fill::zeros); + out.elem( find(im != on) ).fill(inf); + this->DistanceTransform2D(out, inf); + return out; +} + +template +CubeType StructuredForests:: +CopyMakeBorder(CubeType const &InImage, size_t top, + size_t left, size_t bottom, size_t right) +{ + CubeType OutImage(InImage.n_rows + top + bottom, InImage.n_cols + left + right, InImage.n_slices); + + for(size_t i = 0; i < InImage.n_slices; ++i) + { + OutImage.slice(i).submat(top, left, InImage.n_rows + top - 1, InImage.n_cols + left - 1) + = InImage.slice(i); + + for(size_t j = 0; j < right; ++j) + { + OutImage.slice(i).col(InImage.n_cols + left + j).subvec(top, InImage.n_rows + top - 1) + = InImage.slice(i).col(InImage.n_cols - j - 1); + } + + for(size_t j = 0; j < left; ++j) + { + OutImage.slice(i).col(j).subvec(top, InImage.n_rows + top - 1) + = InImage.slice(i).col(left - 1 - j); + } + + for(size_t j = 0; j < top; j++) + { + + OutImage.slice(i).row(j) + = OutImage.slice(i).row(2 * top - 1 - j); + } + + for(size_t j = 0; j < bottom; j++) + { + OutImage.slice(i).row(InImage.n_rows + top + j) + = OutImage.slice(i).row(InImage.n_rows + top - j - 1); + } + + } + return OutImage; +} + +template +CubeType StructuredForests:: +RGB2LUV(CubeType const &InImage) +{ + //assert type is double or float. + double a, y0, maxi; + a = std::pow(29.0, 3) / 27.0; + y0 = 8.0 / a; + maxi = 1.0 / 270.0; + + arma::vec table(1064); + for (size_t i = 0; i <= 1024; ++i) + { + table(i) = i / 1024.0; + + if (table(i) > y0) + table(i) = 116 * pow(table(i), 1.0/3.0) - 16.0; + else + table(i) = table(i) * a; + + table(i) = table(i) * maxi; + } + for(size_t i = 1025; i < table.n_elem; ++i) + { + table(i) = table(i - 1); + } + + MatType rgb2xyz; + rgb2xyz << 0.430574 << 0.222015 << 0.020183 << arma::endr + << 0.341550 << 0.706655 << 0.129553 << arma::endr + << 0.178325 << 0.071330 << 0.939180; + + //see how to calculate this efficiently. numpy.dot does this. 
+ CubeType xyz(InImage.n_rows, InImage.n_cols, rgb2xyz.n_cols); + + for (size_t i = 0; i < InImage.slice(0).n_elem; ++i) + { + double r = InImage.slice(0)(i); + double g = InImage.slice(1)(i); + double b = InImage.slice(2)(i); + + xyz.slice(0)(i) = 0.430574 * r + 0.341550 * g + 0.178325 * b; + xyz.slice(1)(i) = 0.222015 * r + 0.706655 * g + 0.071330 * b; + xyz.slice(2)(i) = 0.020183 * r + 0.129553 * g + 0.939180 * b; + + /* + xyz.slice(0)(i) = 0.430574 * r + 0.341550 * g + 0.178325 * b; + xyz.slice(1)(i) = 0.222015 * r + 0.706655 * g + 0.129553 * b; + xyz.slice(2)(i) = 0.020183 * r + 0.071330 * g + 0.939180 * b; + */ + } + + MatType nz(InImage.n_rows, InImage.n_cols); + + nz = 1.0 / ( xyz.slice(0) + (15 * xyz.slice(1) ) + + (3 * xyz.slice(2) + 1e-35)); + CubeType OutImage(InImage.n_rows, InImage.n_cols, InImage.n_slices); + + for(size_t j = 0; j < xyz.n_cols; ++j) + { + for(size_t i = 0; i < xyz.n_rows; ++i) + { + OutImage(i, j, 0) = table( static_cast( (1024 * xyz(i, j, 1) ) ) ); + } + } + + OutImage.slice(1) = OutImage.slice(0) % (13 * 4 * (xyz.slice(0) % nz) \ + - 13 * 0.197833) + 88 * maxi; + OutImage.slice(2) = OutImage.slice(0) % (13 * 9 * (xyz.slice(1) % nz) \ + - 13 * 0.468331) + 134 * maxi; + + return OutImage; +} + +/*implement this function in a column major order.*/ +template +MatType StructuredForests:: +bilinearInterpolation(MatType const &src, + size_t height, size_t width) +{ + MatType dst(height, width); + double const x_ratio = static_cast((src.n_cols - 1)) / width; + double const y_ratio = static_cast((src.n_rows - 1)) / height; + for(size_t row = 0; row != dst.n_rows; ++row) + { + size_t y = static_cast(row * y_ratio); + double const y_diff = (row * y_ratio) - y; //distance of the nearest pixel(y axis) + double const y_diff_2 = 1 - y_diff; + for(size_t col = 0; col != dst.n_cols; ++col) + { + size_t x = static_cast(col * x_ratio); + double const x_diff = (col * x_ratio) - x; //distance of the nearet pixel(x axis) + double const x_diff_2 = 1 - 
x_diff; + double const y2_cross_x2 = y_diff_2 * x_diff_2; + double const y2_cross_x = y_diff_2 * x_diff; + double const y_cross_x2 = y_diff * x_diff_2; + double const y_cross_x = y_diff * x_diff; + dst(row, col) = y2_cross_x2 * src(y, x) + + y2_cross_x * src(y, x + 1) + + y_cross_x2 * src(y + 1, x) + + y_cross_x * src(y + 1, x + 1); + } + } + + return dst; +} + +template +CubeType StructuredForests:: +sepFilter2D(CubeType &InOutImage, arma::vec &kernel, size_t radius) +{ + CubeType OutImage = this->CopyMakeBorder(InOutImage, radius, radius, radius, radius); + + arma::vec row_res, col_res; + // reverse InOutImage and OutImage to avoid making an extra matrix. + // InImage is renamed to InOutImage in this function for this reason only. + arma::mat k_mat = kernel * kernel.t(); + for(size_t k = 0; k < OutImage.n_slices; ++k) + { + for(size_t j = radius; j < OutImage.n_cols - radius; ++j) + { + for(size_t i = radius; i < OutImage.n_rows - radius; ++i) + { + InOutImage(i - radius, j - radius, k) = + arma::accu(OutImage.slice(k).submat(i - radius, j - radius, i + radius, j + radius) % k_mat); + } + } + } + + return InOutImage; +} + +template +CubeType StructuredForests:: +ConvTriangle(CubeType &InImage, size_t radius) +{ + if (radius == 0) + { + return InImage; + } + else if (radius <= 1) + { + const double p = 12.0 / radius / (radius + 2) - 2; + arma::vec kernel = {1 , p, 1}; + kernel /= (p + 2); + + return this->sepFilter2D(InImage, kernel, radius); + } + else + { + const size_t len = 2 * radius + 1; + arma::vec kernel(len); + for( size_t i = 0; i < radius; ++i) + kernel(i) = i + 1; + + kernel(radius) = radius + 1; + + size_t r = radius; + for( size_t i = radius + 1; i < len; ++i) + kernel(i) = r--; + + kernel /= std::pow(radius + 1, 2); + return this->sepFilter2D(InImage, kernel, radius); + } +} + +//just a helper function, can't use it for anything else +//finds max numbers on cube axis and returns max values, +// also stores the locations of max values in Location 
+template +MatType StructuredForests:: +MaxAndLoc(CubeType &mag, arma::umat &Location) const +{ + /*Vectorize this function after prototype works*/ + MatType MaxVal(Location.n_rows, Location.n_cols); + for(size_t i = 0; i < mag.n_rows; ++i) + { + for(size_t j = 0; j < mag.n_cols; ++j) + { + /*can use -infinity here*/ + double max = std::numeric_limits::min(); + for(size_t k = 0; k < mag.n_slices; ++k) + { + if(mag(i, j, k) > max) + { + max = mag(i, j, k); + MaxVal(i, j) = max; + Location(i, j) = k; + } + } + } + } + return MaxVal; +} + +template +void StructuredForests:: +Gradient(CubeType const &InImage, + MatType &Magnitude, + MatType &Orientation) +{ + const size_t grd_norm_rad = this->options["grd_norm_rad"]; + CubeType dx(InImage.n_rows, InImage.n_cols, InImage.n_slices), + dy(InImage.n_rows, InImage.n_cols, InImage.n_slices); + + dx.zeros(); + dy.zeros(); + + /* + From MATLAB documentation: + [FX,FY] = gradient(F), where F is a matrix, returns the + x and y components of the two-dimensional numerical gradient. + FX corresponds to ∂F/∂x, the differences in x (horizontal) direction. + FY corresponds to ∂F/∂y, the differences in the y (vertical) direction. + */ + + + /* + gradient calculates the central difference for interior data points. + For example, consider a matrix with unit-spaced data, A, that has + horizontal gradient G = gradient(A). The interior gradient values, G(:,j), are: + + G(:,j) = 0.5*(A(:,j+1) - A(:,j-1)); + where j varies between 2 and N-1, where N is size(A,2). + + The gradient values along the edges of the matrix are calculated with single-sided differences, so that + + G(:,1) = A(:,2) - A(:,1); + G(:,N) = A(:,N) - A(:,N-1); + + The spacing between points in each direction is assumed to be one. 
+ */ + for (size_t i = 0; i < InImage.n_slices; ++i) + { + dx.slice(i).col(0) = InImage.slice(i).col(1) - InImage.slice(i).col(0); + dx.slice(i).col(InImage.n_cols - 1) = InImage.slice(i).col(InImage.n_cols - 1) + - InImage.slice(i).col(InImage.n_cols - 2); + + for (size_t j = 1; j < InImage.n_cols-1; j++) + dx.slice(i).col(j) = 0.5 * ( InImage.slice(i).col(j+1) - InImage.slice(i).col(j) ); + + // do same for dy. + dy.slice(i).row(0) = InImage.slice(i).row(1) - InImage.slice(i).row(0); + dy.slice(i).row(InImage.n_rows - 1) = InImage.slice(i).row(InImage.n_rows - 1) + - InImage.slice(i).row(InImage.n_rows - 2); + + for (size_t j = 1; j < InImage.n_rows-1; j++) + dy.slice(i).row(j) = 0.5 * ( InImage.slice(i).row(j+1) - InImage.slice(i).row(j) ); + } + + CubeType mag(InImage.n_rows, InImage.n_cols, InImage.n_slices); + for (size_t i = 0; i < InImage.n_slices; ++i) + { + mag.slice(i) = arma::sqrt( arma::square \ + ( dx.slice(i) + arma::square( dy.slice(i) ) ) ); + } + + arma::umat Location(InImage.n_rows, InImage.n_cols); + Magnitude = this->MaxAndLoc(mag, Location); + if(grd_norm_rad != 0) + { + //we have to do this ugly thing, or override ConvTriangle + // and sepFilter2D methods. 
+ CubeType mag2(InImage.n_rows, InImage.n_cols, 1); + mag2.slice(0) = Magnitude; + mag2 = this->ConvTriangle(mag2, grd_norm_rad); + Magnitude = Magnitude / (mag2.slice(0) + 0.01); + } + MatType dx_mat(dx.n_rows, dx.n_cols),\ + dy_mat(dy.n_rows, dy.n_cols); + + for(size_t j = 0; j < InImage.n_cols; ++j) + { + for(size_t i = 0; i < InImage.n_rows; ++i) + { + dx_mat(i, j) = dx(i, j, Location(i, j)); + dy_mat(i, j) = dy(i, j, Location(i, j)); + } + } + Orientation = arma::atan(dy_mat / dx_mat); + Orientation.transform( [](double val) { if(val < 0) return (val + arma::datum::pi); else return (val);} ); + + for(size_t j = 0; j < InImage.n_cols; ++j) + { + for(size_t i = 0; i < InImage.n_rows; ++i) + { + if( abs(dx_mat(i, j)) + abs(dy_mat(i, j)) < 1E-5) + Orientation(i, j) = 0.5 * arma::datum::pi; + } + } +} + +template +CubeType StructuredForests:: +Histogram(MatType const &Magnitude, + MatType const &Orientation, + size_t downscale, size_t interp) +{ + //i don't think this function can be vectorized. + + //n_orient: number of orientations per gradient scale + const size_t n_orient = this->options["n_orient"]; + //size of HistArr: n_rbin * n_cbin * n_orient . . . (create in caller...) 
+ const size_t n_rbin = (Magnitude.n_rows + downscale - 1) / downscale; + const size_t n_cbin = (Magnitude.n_cols + downscale - 1) / downscale; + double o_range, o; + o_range = arma::datum::pi / n_orient; + + CubeType HistArr(n_rbin, n_cbin, n_orient); + HistArr.zeros(); + + size_t r, c, o1, o2; + for(size_t i = 0; i < Magnitude.n_rows; ++i) + { + for(size_t j = 0; j < Magnitude.n_cols; ++j) + { + r = i / downscale; + c = j / downscale; + + if( interp != 0) + { + o = Orientation(i, j) / o_range; + o1 = ((size_t) o) % n_orient; + o2 = (o1 + 1) % n_orient; + HistArr(r, c, o1) += Magnitude(i, j) * (1 + (int)o - o); + HistArr(r, c, o2) += Magnitude(i, j) * (o - (int) o); + } + else + { + o1 = (size_t) (Orientation(i, j) / o_range + 0.5) % n_orient; + HistArr(r, c, o1) += Magnitude(i, j); + } + } + } + + HistArr = HistArr / downscale; + + for (size_t i = 0; i < HistArr.n_slices; ++i) + HistArr.slice(i) = arma::square(HistArr.slice(i)); + + return HistArr; +} + +template +void StructuredForests:: +GetShrunkChannels(CubeType const &InImage, CubeType ®_ch, CubeType &ss_ch) +{ + CubeType luv = this->RGB2LUV(InImage); + + const size_t shrink = this->options["shrink"]; + const size_t n_orient = this->options["n_orient"]; + const size_t grd_smooth_rad = this->options["grd_smooth_rad"]; + const size_t grd_norm_rad = this->options["grd_norm_rad"]; + const size_t num_channels = 13; + const size_t rsize = luv.n_rows / shrink; + const size_t csize = luv.n_cols / shrink; + CubeType channels(rsize, csize, num_channels); + + + size_t slice_idx = 0; + + for( slice_idx = 0; slice_idx < luv.n_slices; ++slice_idx) + channels.slice(slice_idx) + = this->bilinearInterpolation(luv.slice(slice_idx), (size_t)rsize, (size_t)csize); + + double scale = 0.5; + + while(scale <= 1.0) + { + CubeType img( (luv.n_rows * scale), + (luv.n_cols * scale), + luv.n_slices ); + + for( slice_idx = 0; slice_idx < luv.n_slices; ++slice_idx) + { + img.slice(slice_idx) = + 
this->bilinearInterpolation(luv.slice(slice_idx), + (luv.n_rows * scale), + (luv.n_cols * scale) ); + } + + CubeType OutImage = this->ConvTriangle(img, grd_smooth_rad); + + MatType Magnitude(InImage.n_rows, InImage.n_cols), + Orientation(InImage.n_rows, InImage.n_cols); + + this->Gradient(OutImage, Magnitude, Orientation); + + size_t downscale = std::max(1, (int)(shrink * scale)); + + CubeType Hist = this->Histogram(Magnitude, Orientation, + downscale, 0); + + channels.slice(slice_idx) = + bilinearInterpolation( Magnitude, rsize, csize); + slice_idx++; + for(size_t i = 0; i < InImage.n_slices; ++i) + channels.slice(i + slice_idx) = + bilinearInterpolation( Magnitude, rsize, csize); + slice_idx += 3; + scale += 0.5; + } + + //cout << "size of channels: " << arma::size(channels) << endl; + double reg_smooth_rad, ss_smooth_rad; + reg_smooth_rad = this->options["reg_smooth_rad"] / (double) shrink; + ss_smooth_rad = this->options["ss_smooth_rad"] / (double) shrink; + + + + + if (reg_smooth_rad > 1.0) + reg_ch = this->ConvTriangle(channels, (size_t) (std::round(reg_smooth_rad)) ); + else + reg_ch = this->ConvTriangle(channels, reg_smooth_rad); + + if (ss_smooth_rad > 1.0) + ss_ch = this->ConvTriangle(channels, (size_t) (std::round(ss_smooth_rad)) ); + else + ss_ch = this->ConvTriangle(channels, ss_smooth_rad); + +} + +template +CubeType StructuredForests:: +ViewAsWindows(CubeType const &channels, arma::umat const &loc) +{ + // 500 for pos_loc, and 500 for neg_loc. + // channels = 160, 240, 13. + CubeType features = CubeType(16, 16, 1000 * 13); + const size_t patchSize = 16; + const size_t p = patchSize / 2; + //increase the channel boundary to protect error against image boundaries. 
+ CubeType inc_ch = this->CopyMakeBorder(channels, p, p, p, p); + for (size_t i = 0, channel = 0; i < loc.n_rows; ++i) + { + size_t x = loc(i, 0); + size_t y = loc(i, 1); + + /*(x,y) in channels, is ((x+p), (y+p)) in inc_ch*/ + //cout << "(x,y) = " << x << " " << y << endl; + CubeType patch = inc_ch.tube((x + p) - p, (y + p) - p,\ + (x + p) + p - 1, (y + p) + p - 1); + // since each patch has 13 channel we have to increase the index by 13 + + //cout <<"patch size = " << arma::size(patch) << endl; + + features.slices(channel, channel + 12) = patch; + //cout << "sahi hai " << endl; + channel += 13; + + } + //cout << "successfully returned. . ." << endl; + return features; +} + +template +CubeType StructuredForests:: +Rearrange(CubeType const &channels) +{ + //we do (16,16,13*1000) to 256, 1000, 13, in vectorized code. + CubeType ch = CubeType(256, 1000, 13); + for(size_t i = 0; i < 1000; i++) + { + //MatType m(256, 13); + for(size_t j = 0; j < 13; ++j) + { + size_t sl = (i * j) / 1000; + //cout << "(i,j) = " << i << ", " << j << endl; + ch.slice(sl).col(i) = arma::vectorise(channels.slice(i * j)); + } + } + return ch; +} + +// returns 256 * 1000 * 13 dimension features. +template +CubeType StructuredForests:: +GetRegFtr(CubeType const &channels, arma::umat const &loc) +{ + int shrink = this->options["shrink"]; + int p_size = this->options["p_size"] / shrink; + CubeType wind = this->ViewAsWindows(channels, loc); + return this->Rearrange(wind); +} + +template +CubeType StructuredForests:: +PDist(CubeType const &features, arma::uvec const &grid_pos) +{ + // size of DestArr: + // InImage.n_rows * (InImage.n_rows - 1)/2 * InImage.n_slices + //find nC2 differences, for locations in the grid_pos. + //python: input: (716, 256, 13) --->(716, 25, 13) ; output: (716, 300, 13). 
+ //input features : 256,1000,13; output: 300, 1000, 13 + + CubeType output(300, 1000, 13); + for(size_t k = 0; k < features.n_slices; ++k) + { + size_t r_idx = 0; + for(size_t i = 0; i < grid_pos.n_elem; ++i) //loop length : 25 + { + for(size_t j = i + 1; j < grid_pos.n_elem; ++j) //loop length : 25 + { + output.slice(k).row(r_idx) = features.slice(k).row(grid_pos(i)) + - features.slice(k).row(grid_pos(j)); + ++r_idx; + } + } + } + return output; +} + +//returns 300,1000,13 dimension features. +template +CubeType StructuredForests:: +GetSSFtr(CubeType const &channels, arma::umat const &loc) +{ + const size_t shrink = this->options["shrink"]; + const size_t p_size = this->options["p_size"] / shrink; + + //n_cell: number of self similarity cells + const size_t n_cell = this->options["n_cell"]; + const size_t half_cell_size = (size_t) round(p_size / (2.0 * n_cell)); + + arma::uvec g_pos(n_cell); + for(size_t i = 0; i < n_cell; ++i) + { + g_pos(i) = (size_t)round( (i + 1) * (p_size + 2 * half_cell_size \ + - 1) / (n_cell + 1.0) - half_cell_size); + } + arma::uvec grid_pos(n_cell * n_cell); + size_t k = 0; + for(size_t i = 0; i < n_cell; ++i) + { + for(size_t j = 0; j < n_cell; ++j) + { + grid_pos(k) = g_pos(i) * p_size + g_pos(j); + ++k; + } + } + + CubeType wind = this->ViewAsWindows(channels, loc); + CubeType re_wind = this->Rearrange(wind); + + return this->PDist(re_wind, grid_pos); +} + +template +arma::field StructuredForests:: +GetFeatures(MatType const &image, arma::umat &loc) +{ + const size_t row_size = this->options["row_size"]; + const size_t col_size = this->options["col_size"]; + const size_t bottom = (4 - (image.n_rows / 3) % 4) % 4; + const size_t right = (4 - image.n_cols % 4) % 4; + //cout << "Botttom = " << bottom << " right = " << right << endl; + + CubeType InImage(image.n_rows / 3, image.n_cols, 3); + + for(size_t i = 0; i < 3; ++i) + { + InImage.slice(i) = image.submat(i * row_size, 0, \ + (i + 1) * row_size - 1, col_size - 1); + } + + CubeType 
OutImage = this->CopyMakeBorder(InImage, 0, 0, bottom, right); + + const size_t num_channels = 13; + const size_t shrink = this->options["shrink"]; + const size_t rsize = OutImage.n_rows / shrink; + const size_t csize = OutImage.n_cols / shrink; + + /* this part gives double free or corruption out error + when executed for a second time */ + CubeType reg_ch = CubeType(rsize, csize, num_channels); + CubeType ss_ch = CubeType(rsize, csize, num_channels); + this->GetShrunkChannels(InImage, reg_ch, ss_ch); + + loc /= shrink; + + CubeType reg_ftr = this->GetRegFtr(reg_ch, loc); + CubeType ss_ftr = this->GetSSFtr(ss_ch, loc); + arma::field F(2,1); + F(0,0) = reg_ftr; + F(1,0) = ss_ftr; + return F; +} + +template +void StructuredForests:: +PrepareData(MatType const &InputData) +{ + const size_t num_images = this->options["num_images"]; + const size_t n_tree = this->options["n_tree"]; + const size_t n_pos = this->options["n_pos"]; + const size_t n_neg = this->options["n_neg"]; + const double fraction = 0.25; + const size_t p_size = this->options["p_size"]; + const size_t g_size = this->options["g_size"]; + const size_t shrink = this->options["shrink"]; + const size_t row_size = this->options["row_size"]; + const size_t col_size = this->options["col_size"]; + // p_rad = radius of image patches. + // g_rad = radius of ground truth patches. + const size_t p_rad = p_size / 2, g_rad = g_size / 2; + + arma::vec FtrDim = this->GetFeatureDimension(); + const size_t n_ftr_dim = FtrDim(0) + FtrDim(1); + const size_t n_smp_ftr_dim = (size_t)(n_ftr_dim * fraction); + + for(size_t i = 0; i < n_tree; ++i) + { + //implement the logic for if data already exists. + MatType ftrs = arma::zeros(n_pos + n_neg, n_smp_ftr_dim); + + //effectively a 3d array. . . + MatType lbls = arma::zeros( g_size * g_size, (n_pos + n_neg )); + // still to be done: store features and labels calculated + // in the loop and store it in these Matrices. + // Could use some suggestions for this. 
+ + size_t loop_iter = num_images * 5; + for(size_t j = 0; j < loop_iter; j += 5) + { + MatType img, bnds, segs; + img = InputData.submat(j * row_size, 0, (j + 3) * row_size - 1, col_size - 1); + bnds = InputData.submat( (j + 3) * row_size, 0, \ + (j + 4) * row_size - 1, col_size - 1 ); + segs = InputData.submat( (j + 4) * row_size, 0, \ + (j + 5) * row_size - 1, col_size - 1 ); + + MatType mask = arma::zeros(row_size, col_size); + for(size_t b = 0; b < mask.n_cols; b = b + shrink) + for(size_t a = 0; a < mask.n_rows; a = a + shrink) + mask(a, b) = 1; + mask.col(p_rad - 1).fill(0); + mask.row( (mask.n_rows - 1) - (p_rad - 1) ).fill(0); + mask.submat(0, 0, mask.n_rows - 1, p_rad - 1).fill(0); + mask.submat(0, mask.n_cols - p_rad, mask.n_rows - 1, + mask.n_cols - 1).fill(0); + + // number of positive or negative patches per ground truth. + //int n_patches_per_gt = (int) (ceil( (float)n_pos / num_images )); + const size_t n_patches_per_gt = 500; + //cout << "n_patches_per_gt = " << n_patches_per_gt << endl; + MatType dis = arma::sqrt( this->DistanceTransformImage(bnds, 1) ); + MatType dis2 = dis; + //dis.transform( [](double val, const int& g_rad) { return (double)(val < g_rad); } ); + //dis2.transform( [](double val, const int& g_rad) { return (double)(val >= g_rad); } ); + //dis.elem( arma::find(dis >= g_rad) ).zeros(); + //dis2.elem( arma::find(dis < g_rad) ).zeros(); + + + arma::uvec pos_loc = arma::find( (dis < g_rad) % mask ); + arma::uvec neg_loc = arma::find( (dis >= g_rad) % mask ); + + pos_loc = arma::shuffle(pos_loc); + neg_loc = arma::shuffle(neg_loc); + + arma::umat loc(n_patches_per_gt * 2, 2); + //cout << "pos_loc size: " << arma::size(pos_loc) << " neg_loc size: " << arma::size(neg_loc) << endl; + //cout << "n_patches_per_gt = " << n_patches_per_gt << endl; + for(size_t i = 0; i < n_patches_per_gt; ++i) + { + loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), pos_loc(i) ).t(); + //cout << "pos_loc: " << loc(i, 0) << ", " << loc(i, 1) << 
endl; + } + + for(size_t i = n_patches_per_gt; i < 2 * n_patches_per_gt; ++i) + { + loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), neg_loc(i) ).t(); + //cout << "neg_loc: " << loc(i, 0) << ", " << loc(i, 1) << endl; + } + + // cout << "num patches = " << n_patches_per_gt << " num elements + = " << pos_loc.n_elem\ + // << " num elements - = " << neg_loc.n_elem << " dis.size " << dis.n_elem << endl; + + //Field F contains reg_ftr and ss_ftr for one image. + arma::field F = this->GetFeatures(img, loc); + //randomly sample 70 values each from reg_ftr and ss_ftr. + /* + CubeType ftr(140, 1000, 13); + arma::uvec r = (0, 255, 256); + arma::uvec s = (0, 299, 300); + arma::uvec rs = r.shuffle(); + arma::uvec ss = s.shuffle(); + */ + MatType lbl(g_size * g_size, 1000); + CubeType s(segs.n_rows, segs.n_cols, 1); + + // have to do this or we can overload the CopyMakeBorder to support MatType. + s.slice(0) = segs; + CubeType in_segs = this->CopyMakeBorder(s, g_rad, + g_rad, g_rad, g_rad); + for(size_t i = 0; i < loc.n_rows; ++i) + { + size_t x = loc(i, 0); size_t y = loc(i, 1); + //cout << "x, y = " << x << " " << y << endl; + lbl.col(i) = arma::vectorise(in_segs.slice(0)\ + .submat((x + g_rad) - g_rad, (y + g_rad) - g_rad,\ + (x + g_rad) + g_rad - 1, (y + g_rad) + g_rad - 1)); + } + } + } +} +/* +template +void StructuredForests:: +Discretize(MatType const &labels, size_t n_class, size_t n_sample) +{ + // Map labels to discrete class labels. + // lbls : 256 * 20000. + // n_sample: number of samples for clustering structured labels 256 + + // see the return type. + arma::uvec lis1(n_sample); + + MatType zs(n_sample, lbls.n_cols); + for (size_t i = 0; i < lis1.n_elem; ++i) + lis1(i) = i; + MatType DiscreteLabels = arma::zeros(n_sample, n); + + for (size_t i = 0; i < labels.n_cols; ++i) + { + arma::uvec z1 = lis1.shuffle(); + arma::uvec z2 = lis2.shuffle(); + for (size_t j = 0; j < zs.n_rows; ++i) + zs(i, j) = (labels(i, z1(j)) == labels(i, z2(j))) ? 
1 : 0; + } + zs -= arma::mean(zs, 1); // calculate mean about cols. n_col = 256. + if ( arma::find(zs).n_elem == 0 ) + { + labels.fill(ones); + } + else + { + //find most representative segs + } + // discretize zs by discretizing pca dimensions + size_t d = min(5, n_sample, (size_t)floor(log(n_class, 2))); + zs = pca(); + +}*/ +} // namespace structured_tree +} // namespace mlpack +#endif + diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt index 8b36a941c97..abf91d59de9 100644 --- a/src/mlpack/tests/CMakeLists.txt +++ b/src/mlpack/tests/CMakeLists.txt @@ -17,6 +17,7 @@ add_executable(mlpack_test det_test.cpp distribution_test.cpp emst_test.cpp + edge_boxes_test.cpp fastmks_test.cpp feedforward_network_test.cpp gmm_test.cpp diff --git a/src/mlpack/tests/edge_boxes_test.cpp b/src/mlpack/tests/edge_boxes_test.cpp new file mode 100644 index 00000000000..0d699a0034b --- /dev/null +++ b/src/mlpack/tests/edge_boxes_test.cpp @@ -0,0 +1,187 @@ +/** + * @file edge_boxes_test.cpp + * @author Nilay Jain + * + * Tests for functions in edge_boxes algorithm. 
+ */ + +#include +#include +#include + +#include +#include "old_boost_test_definitions.hpp" + +using namespace mlpack; +using namespace mlpack::structured_tree; +/* + + + //void GetShrunkChannels(CubeType& InImage, CubeType& reg_ch, CubeType& ss_ch); + + CubeType RGB2LUV(CubeType& InImage); + + void Gradient(CubeType& InImage, + MatType& Magnitude, + MatType& Orientation); + + + CubeType Histogram(MatType& Magnitude, + MatType& Orientation, + int downscale, int interp); + + /* + CubeType ViewAsWindows(CubeType& channels, arma::umat& loc); + + CubeType GetRegFtr(CubeType& channels, arma::umat& loc); + + CubeType GetSSFtr(CubeType& channels, arma::umat& loc); + + CubeType Rearrange(CubeType& channels); + + CubeType PDist(CubeType& features, arma::uvec& grid_pos); + ***-/ +*/ + +BOOST_AUTO_TEST_SUITE(EdgeBoxesTest); + +/** + * This test checks the feature extraction functions + * mentioned in feature_extraction.hpp + */ + +void Test(arma::mat m1, arma::mat m2) +{ + for (size_t i = 0; i < m1.n_cols; ++i) + BOOST_REQUIRE_CLOSE(m1(i), m2(i), 1e-2); +} + +void Test(arma::cube m1, arma::cube m2) +{ + for (size_t i = 0; i < m1.n_cols; ++i) + BOOST_REQUIRE_CLOSE(m1(i), m2(i), 1e-2); +} + +void DistanceTransformTest(arma::mat& input, + double on, arma::mat& output, + StructuredForests& SF) +{ + arma::mat dt_output = SF.DistanceTransformImage(input, on); + Test(dt_output, output); +} + +void CopyMakeBorderTest(arma::cube& input, + arma::cube& output, + StructuredForests& SF) +{ + arma::cube border_output = SF.CopyMakeBorder(input, 1, 1, 1, 1); + Test(border_output, output); +} + +void RGB2LUVTest(arma::cube& input, arma::cube& output, + StructuredForests& SF) +{ + arma::cube luv = SF.RGB2LUV(input); + Test(luv, output); +} + +void ConvTriangleTest(arma::cube& input, int radius, + arma::cube& output, StructuredForests& SF) +{ + arma::cube conv_out = SF.ConvTriangle(input, radius); + Test(conv_out, output); +} + +BOOST_AUTO_TEST_CASE(FeatureExtractionTest) +{ + std::map options; 
+ options["num_images"] = 2; + options["row_size"] = 321; + options["col_size"] = 481; + options["rgbd"] = 0; + options["shrink"] = 2; + options["n_orient"] = 4; + options["grd_smooth_rad"] = 0; + options["grd_norm_rad"] = 4; + options["reg_smooth_rad"] = 2; + options["ss_smooth_rad"] = 8; + options["p_size"] = 32; + options["g_size"] = 16; + options["n_cell"] = 5; + + options["n_pos"] = 10000; + options["n_neg"] = 10000; + options["n_tree"] = 8; + options["n_class"] = 2; + options["min_count"] = 1; + options["min_child"] = 8; + options["max_depth"] = 64; + options["split"] = 0; // we use 0 for gini, 1 for entropy, 2 for other + options["stride"] = 2; + options["sharpen"] = 2; + options["n_tree_eval"] = 4; + options["nms"] = 1; // 1 for true, 0 for false + + arma::mat input, output; + input << 0 << 0 << 0 << arma::endr + << 0 << 1 << 0 << arma::endr + << 1 << 0 << 0; + + output << 2 << 1 << 2 << arma::endr + << 1 << 0 << 1 << arma::endr + << 0 << 1 << 2; + StructuredForests SF(options); + DistanceTransformTest(input, 1, output, SF); + + arma::cube in1(input.n_rows, input.n_cols, 1); + arma::cube c1(input.n_rows, input.n_cols, 1); + + in1.slice(0) = output; + + arma::mat out_border; + out_border << 2 << 2 << 1 << 2 << 2 << arma::endr + << 2 << 2 << 1 << 2 << 2 << arma::endr + << 1 << 1 << 0 << 1 << 1 << arma::endr + << 0 << 0 << 1 << 2 << 2 << arma::endr + << 0 << 0 << 1 << 2 << 2; + arma::cube out_b(out_border.n_rows, out_border.n_cols, 1); + out_b.slice(0) = out_border; + CopyMakeBorderTest(in1, out_b, SF); + + arma::mat out_conv; + + out_conv << 1.20987 << 1.25925 << 1.30864 << arma::endr + << 0.96296 << 1.11111 << 1.25925 << arma::endr + << 0.71604 << 0.96296 << 1.20987; + + + c1.slice(0) = out_conv; + + ConvTriangleTest(in1, 2, c1, SF); + + +arma::cube out_luv(3, 3, 3); +out_luv.slice(0) << 0.191662 << 0.139897 << 0.191662 << arma::endr + << 0.139897 << 0.0 << 0.139897 << arma::endr + << 0.0 << 0.139897 << 0.191662; + +out_luv.slice(1) << 0.325926 << 0.325926 
<< 0.325926 << arma::endr + << 0.325926 << 0.325926 << 0.325926 << arma::endr + << 0.325926 << 0.325926 << 0.325926; + +out_luv.slice(2) << 0.496295 << 0.496295 << 0.496295 << arma::endr + << 0.496295 << 0.496295 << 0.496295 << arma::endr + << 0.496295 << 0.496295 << 0.496295; + + + arma::cube in_luv(output.n_rows, output.n_cols, 3); + for(size_t i = 0; i < in_luv.n_slices; ++i) + { + in_luv.slice(i) = output / 10; + } + + RGB2LUVTest(in_luv, out_luv, SF); +} + +BOOST_AUTO_TEST_SUITE_END(); + From 6e1740b14ad3258a3b9835b1349fba8e8e9b4a3b Mon Sep 17 00:00:00 2001 From: nilayjain Date: Thu, 16 Jun 2016 03:48:46 +0000 Subject: [PATCH 02/14] not working right now --- .../edge_boxes/feature_extraction_impl.hpp | 217 +++++++++--------- .../methods/edge_boxes/feature_parameters.hpp | 89 +++++++ 2 files changed, 197 insertions(+), 109 deletions(-) create mode 100644 src/mlpack/methods/edge_boxes/feature_parameters.hpp diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp index c7d7b87a980..c053787a4bf 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ -16,10 +16,9 @@ namespace structured_tree { template StructuredForests:: -StructuredForests(const std::map inMap) - : options(std::move(inMap)) +StructuredForests() { - + // to do. 
} template @@ -65,8 +64,8 @@ LoadData(MatType const &images, MatType const &boundaries,\ } template -arma::vec StructuredForests:: -GetFeatureDimension() +void StructuredForests:: +GetFeatureDimension(arma::vec &FtrDim) { /* shrink: amount to shrink channels @@ -74,7 +73,7 @@ GetFeatureDimension() n_cell: number of self similarity cells n_orient: number of orientations per gradient scale */ - arma::vec FtrDim(2); + FtrDim = arma::vec(2); const size_t shrink = this->options["shrink"]; const size_t p_size = this->options["p_size"]; @@ -97,14 +96,15 @@ GetFeatureDimension() const size_t n_ch = n_color_ch + n_grad_ch; FtrDim[0] = std::pow((p_size / shrink) , 2) * n_ch; FtrDim[1] = std::pow(n_cell , 2) * (std::pow (n_cell, 2) - 1) / 2 * n_ch; - return FtrDim; } template -arma::vec StructuredForests:: -DistanceTransform1D(arma::vec const &f, const size_t n, const double inf) +void StructuredForests:: +DistanceTransform1D(const arma::vec& f, const size_t n, const double inf, + arma::vec& d) { - arma::vec d(n), v(n), z(n + 1); + arma::vec v(n), z(n + 1); + d = arma::vec(n); size_t k = 0; v[0] = 0.0; z[0] = -inf; @@ -143,7 +143,8 @@ DistanceTransform2D(MatType &im, const double inf) for (size_t x = 0; x < im.n_cols; ++x) { f.subvec(0, im.n_rows - 1) = im.col(x); - arma::vec d = this->DistanceTransform1D(f, im.n_rows, inf); + arma::vec d; + this->DistanceTransform1D(f, im.n_rows, inf, d); im.col(x) = d; } @@ -151,30 +152,31 @@ DistanceTransform2D(MatType &im, const double inf) for (size_t y = 0; y < im.n_rows; y++) { f.subvec(0, im.n_cols - 1) = im.row(y).t(); - arma::vec d = this->DistanceTransform1D(f, im.n_cols, inf); + arma::vec d; + this->DistanceTransform1D(f, im.n_cols, inf, d); im.row(y) = d.t(); } } /* euclidean distance transform of binary image using squared distance */ template -MatType StructuredForests:: -DistanceTransformImage(MatType const &im, double on) +void StructuredForests:: +DistanceTransformImage(const MatType& im, double on, MatType& out) { //need a 
large value but not infinity. double inf = 999999.99; MatType out = MatType(im.n_rows, im.n_cols, arma::fill::zeros); out.elem( find(im != on) ).fill(inf); this->DistanceTransform2D(out, inf); - return out; } template -CubeType StructuredForests:: -CopyMakeBorder(CubeType const &InImage, size_t top, - size_t left, size_t bottom, size_t right) +void StructuredForests:: +CopyMakeBorder(const CubeType& InImage, size_t top, + size_t left, size_t bottom, size_t right + CubeType& OutImage) { - CubeType OutImage(InImage.n_rows + top + bottom, InImage.n_cols + left + right, InImage.n_slices); + OutImage = MatType(InImage.n_rows + top + bottom, InImage.n_cols + left + right, InImage.n_slices); for(size_t i = 0; i < InImage.n_slices; ++i) { @@ -207,12 +209,11 @@ CopyMakeBorder(CubeType const &InImage, size_t top, } } - return OutImage; } template -CubeType StructuredForests:: -RGB2LUV(CubeType const &InImage) +void StructuredForests:: +RGB2LUV(const CubeType& InImage, CubeType OutImage) { //assert type is double or float. 
double a, y0, maxi; @@ -266,7 +267,7 @@ RGB2LUV(CubeType const &InImage) nz = 1.0 / ( xyz.slice(0) + (15 * xyz.slice(1) ) + (3 * xyz.slice(2) + 1e-35)); - CubeType OutImage(InImage.n_rows, InImage.n_cols, InImage.n_slices); + OutImage = MatType(InImage.n_rows, InImage.n_cols, InImage.n_slices); for(size_t j = 0; j < xyz.n_cols; ++j) { @@ -280,17 +281,16 @@ RGB2LUV(CubeType const &InImage) - 13 * 0.197833) + 88 * maxi; OutImage.slice(2) = OutImage.slice(0) % (13 * 9 * (xyz.slice(1) % nz) \ - 13 * 0.468331) + 134 * maxi; - - return OutImage; } /*implement this function in a column major order.*/ template -MatType StructuredForests:: -bilinearInterpolation(MatType const &src, - size_t height, size_t width) +void StructuredForests:: +BilinearInterpolation(MatType const &src, + size_t height, size_t width, + MatType dst) { - MatType dst(height, width); + dst = MatType(height, width); double const x_ratio = static_cast((src.n_cols - 1)) / width; double const y_ratio = static_cast((src.n_rows - 1)) / height; for(size_t row = 0; row != dst.n_rows; ++row) @@ -313,20 +313,19 @@ bilinearInterpolation(MatType const &src, y_cross_x * src(y + 1, x + 1); } } - - return dst; } template -CubeType StructuredForests:: -sepFilter2D(CubeType &InOutImage, arma::vec &kernel, size_t radius) +void StructuredForests:: +SepFilter2D(CubeType &InOutImage, arma::vec &kernel, size_t radius) { - CubeType OutImage = this->CopyMakeBorder(InOutImage, radius, radius, radius, radius); + CubeType OutImage; + this->CopyMakeBorder(InOutImage, radius, radius, radius, radius, OutImage); arma::vec row_res, col_res; // reverse InOutImage and OutImage to avoid making an extra matrix. // InImage is renamed to InOutImage in this function for this reason only. 
- arma::mat k_mat = kernel * kernel.t(); + MatType k_mat = kernel * kernel.t(); for(size_t k = 0; k < OutImage.n_slices; ++k) { for(size_t j = radius; j < OutImage.n_cols - radius; ++j) @@ -334,21 +333,22 @@ sepFilter2D(CubeType &InOutImage, arma::vec &kernel, size_t radius) for(size_t i = radius; i < OutImage.n_rows - radius; ++i) { InOutImage(i - radius, j - radius, k) = - arma::accu(OutImage.slice(k).submat(i - radius, j - radius, i + radius, j + radius) % k_mat); + arma::accu(OutImage.slice(k)\ + .submat(i - radius, j - radius,\ + i + radius, j + radius) % k_mat); } } } - return InOutImage; } template -CubeType StructuredForests:: +void StructuredForests:: ConvTriangle(CubeType &InImage, size_t radius) { if (radius == 0) { - return InImage; + //nothing to do } else if (radius <= 1) { @@ -356,7 +356,7 @@ ConvTriangle(CubeType &InImage, size_t radius) arma::vec kernel = {1 , p, 1}; kernel /= (p + 2); - return this->sepFilter2D(InImage, kernel, radius); + this->sepFilter2D(InImage, kernel, radius); } else { @@ -372,7 +372,7 @@ ConvTriangle(CubeType &InImage, size_t radius) kernel(i) = r--; kernel /= std::pow(radius + 1, 2); - return this->sepFilter2D(InImage, kernel, radius); + this->sepFilter2D(InImage, kernel, radius); } } @@ -380,11 +380,11 @@ ConvTriangle(CubeType &InImage, size_t radius) //finds max numbers on cube axis and returns max values, // also stores the locations of max values in Location template -MatType StructuredForests:: -MaxAndLoc(CubeType &mag, arma::umat &Location) const +void StructuredForests:: +MaxAndLoc(CubeType &mag, arma::umat &Location, CubeType& MaxVal) const { /*Vectorize this function after prototype works*/ - MatType MaxVal(Location.n_rows, Location.n_cols); + MaxVal = MatType(Location.n_rows, Location.n_cols); for(size_t i = 0; i < mag.n_rows; ++i) { for(size_t j = 0; j < mag.n_cols; ++j) @@ -402,12 +402,11 @@ MaxAndLoc(CubeType &mag, arma::umat &Location) const } } } - return MaxVal; } template void StructuredForests:: 
-Gradient(CubeType const &InImage, +Gradient(const CubeType &InImage, MatType &Magnitude, MatType &Orientation) { @@ -468,14 +467,14 @@ Gradient(CubeType const &InImage, } arma::umat Location(InImage.n_rows, InImage.n_cols); - Magnitude = this->MaxAndLoc(mag, Location); + this->MaxAndLoc(mag, Location, Magnitude); if(grd_norm_rad != 0) { //we have to do this ugly thing, or override ConvTriangle // and sepFilter2D methods. CubeType mag2(InImage.n_rows, InImage.n_cols, 1); mag2.slice(0) = Magnitude; - mag2 = this->ConvTriangle(mag2, grd_norm_rad); + this->ConvTriangle(mag2, grd_norm_rad); Magnitude = Magnitude / (mag2.slice(0) + 0.01); } MatType dx_mat(dx.n_rows, dx.n_cols),\ @@ -503,10 +502,11 @@ Gradient(CubeType const &InImage, } template -CubeType StructuredForests:: -Histogram(MatType const &Magnitude, - MatType const &Orientation, - size_t downscale, size_t interp) +void StructuredForests:: +Histogram(const MatType& Magnitude, + const MatType& Orientation, + size_t downscale, size_t interp, + CubeType& HistArr) { //i don't think this function can be vectorized. 
@@ -518,7 +518,7 @@ Histogram(MatType const &Magnitude, double o_range, o; o_range = arma::datum::pi / n_orient; - CubeType HistArr(n_rbin, n_cbin, n_orient); + HistArr = CubeType(n_rbin, n_cbin, n_orient); HistArr.zeros(); size_t r, c, o1, o2; @@ -549,15 +549,14 @@ Histogram(MatType const &Magnitude, for (size_t i = 0; i < HistArr.n_slices; ++i) HistArr.slice(i) = arma::square(HistArr.slice(i)); - - return HistArr; } template void StructuredForests:: -GetShrunkChannels(CubeType const &InImage, CubeType ®_ch, CubeType &ss_ch) +GetShrunkChannels(const CubeType& InImage, CubeType ®_ch, CubeType &ss_ch) { - CubeType luv = this->RGB2LUV(InImage); + CubeType luv; + this->RGB2LUV(InImage, luv); const size_t shrink = this->options["shrink"]; const size_t n_orient = this->options["n_orient"]; @@ -572,8 +571,8 @@ GetShrunkChannels(CubeType const &InImage, CubeType ®_ch, CubeType &ss_ch) size_t slice_idx = 0; for( slice_idx = 0; slice_idx < luv.n_slices; ++slice_idx) - channels.slice(slice_idx) - = this->bilinearInterpolation(luv.slice(slice_idx), (size_t)rsize, (size_t)csize); + this->BilinearInterpolation(luv.slice(slice_idx), (size_t)rsize, (size_t)csize + channels.slice(slice_idx)); double scale = 0.5; @@ -585,10 +584,9 @@ GetShrunkChannels(CubeType const &InImage, CubeType ®_ch, CubeType &ss_ch) for( slice_idx = 0; slice_idx < luv.n_slices; ++slice_idx) { - img.slice(slice_idx) = - this->bilinearInterpolation(luv.slice(slice_idx), - (luv.n_rows * scale), - (luv.n_cols * scale) ); + this->BilinearInterpolation(luv.slice(slice_idx), + (luv.n_rows * scale), (luv.n_cols * scale) + img.slice(slice_idx)); } CubeType OutImage = this->ConvTriangle(img, grd_smooth_rad); @@ -603,12 +601,11 @@ GetShrunkChannels(CubeType const &InImage, CubeType ®_ch, CubeType &ss_ch) CubeType Hist = this->Histogram(Magnitude, Orientation, downscale, 0); - channels.slice(slice_idx) = - bilinearInterpolation( Magnitude, rsize, csize); + BilinearInterpolation( Magnitude, rsize, csize, 
channels.slice(slice_idx)); slice_idx++; for(size_t i = 0; i < InImage.n_slices; ++i) - channels.slice(i + slice_idx) = - bilinearInterpolation( Magnitude, rsize, csize); + BilinearInterpolation( Magnitude, rsize, csize,\ + channels.slice(i + slice_idx)); slice_idx += 3; scale += 0.5; } @@ -634,23 +631,24 @@ GetShrunkChannels(CubeType const &InImage, CubeType ®_ch, CubeType &ss_ch) } template -CubeType StructuredForests:: -ViewAsWindows(CubeType const &channels, arma::umat const &loc) +void StructuredForests:: +ViewAsWindows(const CubeType& channels, arma::umat const &loc, + CubeType& features) { // 500 for pos_loc, and 500 for neg_loc. // channels = 160, 240, 13. - CubeType features = CubeType(16, 16, 1000 * 13); + features = CubeType(16, 16, 1000 * 13); const size_t patchSize = 16; const size_t p = patchSize / 2; //increase the channel boundary to protect error against image boundaries. - CubeType inc_ch = this->CopyMakeBorder(channels, p, p, p, p); + CubeType inc_ch; + this->CopyMakeBorder(channels, p, p, p, p, inc_ch); for (size_t i = 0, channel = 0; i < loc.n_rows; ++i) { size_t x = loc(i, 0); size_t y = loc(i, 1); /*(x,y) in channels, is ((x+p), (y+p)) in inc_ch*/ - //cout << "(x,y) = " << x << " " << y << endl; CubeType patch = inc_ch.tube((x + p) - p, (y + p) - p,\ (x + p) + p - 1, (y + p) + p - 1); // since each patch has 13 channel we have to increase the index by 13 @@ -658,21 +656,17 @@ ViewAsWindows(CubeType const &channels, arma::umat const &loc) //cout <<"patch size = " << arma::size(patch) << endl; features.slices(channel, channel + 12) = patch; - //cout << "sahi hai " << endl; channel += 13; - } - //cout << "successfully returned. . ." << endl; - return features; } template -CubeType StructuredForests:: -Rearrange(CubeType const &channels) +void StructuredForests:: +Rearrange(CubeType const &channels, CubeType& ch) { //we do (16,16,13*1000) to 256, 1000, 13, in vectorized code. 
- CubeType ch = CubeType(256, 1000, 13); - for(size_t i = 0; i < 1000; i++) + ch = CubeType(256, 1000, 13); + for(size_t i = 0; i < 1000; ++i) { //MatType m(256, 13); for(size_t j = 0; j < 13; ++j) @@ -682,23 +676,25 @@ Rearrange(CubeType const &channels) ch.slice(sl).col(i) = arma::vectorise(channels.slice(i * j)); } } - return ch; } // returns 256 * 1000 * 13 dimension features. template -CubeType StructuredForests:: -GetRegFtr(CubeType const &channels, arma::umat const &loc) +void StructuredForests:: +GetRegFtr(const CubeType& channels,const arma::umat& loc + CubeType& RegFtr) { int shrink = this->options["shrink"]; int p_size = this->options["p_size"] / shrink; - CubeType wind = this->ViewAsWindows(channels, loc); - return this->Rearrange(wind); + CubeType wind; + this->ViewAsWindows(channels, loc, wind); + this->Rearrange(wind, RegFtr); } template CubeType StructuredForests:: -PDist(CubeType const &features, arma::uvec const &grid_pos) +PDist(const CubeType& features, const arma::uvec& grid_pos, + CubeType& output) { // size of DestArr: // InImage.n_rows * (InImage.n_rows - 1)/2 * InImage.n_slices @@ -706,7 +702,7 @@ PDist(CubeType const &features, arma::uvec const &grid_pos) //python: input: (716, 256, 13) --->(716, 25, 13) ; output: (716, 300, 13). //input features : 256,1000,13; output: 300, 1000, 13 - CubeType output(300, 1000, 13); + output = CubeType(300, 1000, 13); for(size_t k = 0; k < features.n_slices; ++k) { size_t r_idx = 0; @@ -720,13 +716,13 @@ PDist(CubeType const &features, arma::uvec const &grid_pos) } } } - return output; } //returns 300,1000,13 dimension features. 
template -CubeType StructuredForests:: -GetSSFtr(CubeType const &channels, arma::umat const &loc) +void StructuredForests:: +GetSSFtr(CubeType const &channels, arma::umat const &loc + CubeType SSFtr) { const size_t shrink = this->options["shrink"]; const size_t p_size = this->options["p_size"] / shrink; @@ -752,15 +748,17 @@ GetSSFtr(CubeType const &channels, arma::umat const &loc) } } - CubeType wind = this->ViewAsWindows(channels, loc); - CubeType re_wind = this->Rearrange(wind); - - return this->PDist(re_wind, grid_pos); + CubeType wind; + this->ViewAsWindows(channels, loc, wind); + CubeType re_wind; + this->Rearrange(wind, re_wind); + this->PDist(re_wind, grid_pos, SSFtr); } template -arma::field StructuredForests:: -GetFeatures(MatType const &image, arma::umat &loc) +void StructuredForests:: +GetFeatures(const MatType &image, arma::umat &loc, + CubeType& RegFtr, CubeType& SSFtr) { const size_t row_size = this->options["row_size"]; const size_t col_size = this->options["col_size"]; @@ -776,7 +774,8 @@ GetFeatures(MatType const &image, arma::umat &loc) (i + 1) * row_size - 1, col_size - 1); } - CubeType OutImage = this->CopyMakeBorder(InImage, 0, 0, bottom, right); + CubeType OutImage; + this->CopyMakeBorder(InImage, 0, 0, bottom, right, OutImage); const size_t num_channels = 13; const size_t shrink = this->options["shrink"]; @@ -791,12 +790,8 @@ GetFeatures(MatType const &image, arma::umat &loc) loc /= shrink; - CubeType reg_ftr = this->GetRegFtr(reg_ch, loc); - CubeType ss_ftr = this->GetSSFtr(ss_ch, loc); - arma::field F(2,1); - F(0,0) = reg_ftr; - F(1,0) = ss_ftr; - return F; + this->GetRegFtr(reg_ch, loc, RegFtr); + this->GetSSFtr(ss_ch, loc, SSFtr); } template @@ -817,7 +812,8 @@ PrepareData(MatType const &InputData) // g_rad = radius of ground truth patches. 
const size_t p_rad = p_size / 2, g_rad = g_size / 2; - arma::vec FtrDim = this->GetFeatureDimension(); + arma::vec FtrDim; + this->GetFeatureDimension(FtrDim); const size_t n_ftr_dim = FtrDim(0) + FtrDim(1); const size_t n_smp_ftr_dim = (size_t)(n_ftr_dim * fraction); @@ -856,8 +852,9 @@ PrepareData(MatType const &InputData) //int n_patches_per_gt = (int) (ceil( (float)n_pos / num_images )); const size_t n_patches_per_gt = 500; //cout << "n_patches_per_gt = " << n_patches_per_gt << endl; - MatType dis = arma::sqrt( this->DistanceTransformImage(bnds, 1) ); - MatType dis2 = dis; + MatType dis; + this->DistanceTransformImage(bnds, 1, dis) + dis = arma::sqrt(dis); //dis.transform( [](double val, const int& g_rad) { return (double)(val < g_rad); } ); //dis2.transform( [](double val, const int& g_rad) { return (double)(val >= g_rad); } ); //dis.elem( arma::find(dis >= g_rad) ).zeros(); @@ -888,7 +885,7 @@ PrepareData(MatType const &InputData) // cout << "num patches = " << n_patches_per_gt << " num elements + = " << pos_loc.n_elem\ // << " num elements - = " << neg_loc.n_elem << " dis.size " << dis.n_elem << endl; - //Field F contains reg_ftr and ss_ftr for one image. + CubeType SSFtr, RegFtr; arma::field F = this->GetFeatures(img, loc); //randomly sample 70 values each from reg_ftr and ss_ftr. /* @@ -903,8 +900,10 @@ PrepareData(MatType const &InputData) // have to do this or we can overload the CopyMakeBorder to support MatType. 
s.slice(0) = segs; - CubeType in_segs = this->CopyMakeBorder(s, g_rad, - g_rad, g_rad, g_rad); + CubeType in_segs; + this->CopyMakeBorder(s, g_rad, + g_rad, g_rad, g_rad, in_segs); + for(size_t i = 0; i < loc.n_rows; ++i) { size_t x = loc(i, 0); size_t y = loc(i, 1); diff --git a/src/mlpack/methods/edge_boxes/feature_parameters.hpp b/src/mlpack/methods/edge_boxes/feature_parameters.hpp new file mode 100644 index 00000000000..567fc6c1479 --- /dev/null +++ b/src/mlpack/methods/edge_boxes/feature_parameters.hpp @@ -0,0 +1,89 @@ +/** + * @file feature_extraction_impl.hpp + * @author Nilay Jain + * + * Implementation of feature parameter class. + */ + +#ifndef MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_IMPL_HPP +#define MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_IMPL_HPP + +namespace mlpack { +namespace structured_tree { + +//This class holds all the fields for the FeatureExtraction class. +class FeatureParameters +{ + public: + + FeatureParameters(); //default constructor + + void NumImages(size_t value) { numImages = value; } + size_t NumImages() const { return numImages; } + + void RowSize(size_t value) { rowSize = value; } + size_t RowSize() const { return rowSize; } + + void ColSize(size_t value) { colSize = value; } + size_t ColSize() const { return colSize; } + + void RGBD(size_t value) { rgbd = value; } + size_t RGBD() const { return rgbd; } + + void Shrink(size_t value) { shrink = value; } + size_t Shrink() const { return shrink; } + + void NumOrient(size_t value) { numOrient = value; } + size_t NumOrient() const { return numOrient; } + + void GrdSmoothRad(size_t value) { grdSmoothRad = value; } + size_t GrdSmoothRad() const { return grdSmoothRad; } + + void GrdNormRad(size_t value) { grdNormRad = value; } + size_t GrdNormRad() const { return grdNormRad; } + + void RegSmoothRad(size_t value) { regSmoothRad = value; } + size_t RegSmoothRad() const { return regSmoothRad; } + + void SSSmoothRad(size_t value) { ssSmoothRad = value; } + size_t SSSmoothRad() const { 
return ssSmoothRad; } + + void PSize(size_t value) { pSize = value; } + size_t PSize() const { return pSize; } + + void GSize(size_t value) { gSize = value; } + size_t GSize() const { return gSize; } + + void NumCell(size_t value) { numCell = value; } + size_t NumCell() const { return numCell; } + + void NumPos(size_t value) { numPos = value; } + size_t NumPos() const { return numPos; } + + void NumNeg(size_t value) { numNeg = value; } + size_t NumNeg() const { return numNeg; } + + void Fraction(double value) { fraction = value; } + double Fraction() const { return fraction; } + + private: + size_t numImages; + size_t rowSize; + size_t colSize; + size_t rgbd; + size_t shrink; + size_t numOrient; + size_t grdSmoothRad; + size_t grdNormRad; + size_t regSmoothRad; + size_t ssSmoothRad; + size_t pSize; + size_t gSize; + size_t numCell; + size_t numPos; + size_t numNeg; + double numCell; +}; + +} +} From 5bab9539341f769809af4be48ce58f7ae88b36f8 Mon Sep 17 00:00:00 2001 From: Jain Date: Fri, 17 Jun 2016 17:52:19 +0530 Subject: [PATCH 03/14] working on it --- .../methods/edge_boxes/feature_extraction.hpp | 24 +- .../edge_boxes/feature_extraction_impl.hpp | 455 ++++++++++-------- .../methods/edge_boxes/feature_parameters.hpp | 7 +- 3 files changed, 266 insertions(+), 220 deletions(-) diff --git a/src/mlpack/methods/edge_boxes/feature_extraction.hpp b/src/mlpack/methods/edge_boxes/feature_extraction.hpp index 68cae5f4acb..ed371f73c97 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction.hpp @@ -16,6 +16,8 @@ namespace structured_tree { template class StructuredForests { + private: + FeatureParameters params; public: @@ -23,23 +25,23 @@ class StructuredForests std::map options; - StructuredForests(const std::map inMap); - - MatType LoadData(MatType const &images, MatType const &boundaries,\ - MatType const &segmentations); + void StructuredForests(FeatureParameters F); +/* MatType LoadData(MatType const 
&images, MatType const &boundaries,\ + MatType const &segmentations);*/ void PrepareData(MatType const &InputData); - arma::vec GetFeatureDimension(); - - arma::vec DistanceTransform1D(arma::vec const &f, const size_t n,\ - const double inf); + void GetFeatureDimension(arma::vec &FtrDim); - void DistanceTransform2D(MatType &im, const double inf); + void DistanceTransform1D(const arma::vec& f, const size_t n,\ + const double inf, arma::vec& d); - MatType DistanceTransformImage(MatType const &im, double on); + void DistanceTransform2D(MatType &Im, const double inf); - arma::field GetFeatures(MatType const &image, arma::umat &loc); + void DistanceTransformImage(const MatType& Im, double on, MatType& Out); + + void GetFeatures(const MatType &Image, arma::umat &loc, + CubeType& RegFtr, CubeType& SSFtr); CubeType CopyMakeBorder(CubeType const &InImage, size_t top, size_t left, size_t bottom, size_t right); diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp index c053787a4bf..5d6abb98a5c 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ -1,11 +1,11 @@ /** - * @file feature_extraction_impl.hpp + * @file feature_extraction_Impl.hpp * @author Nilay Jain * * Implementation of feature extraction methods. */ -#ifndef MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_IMPL_HPP -#define MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_IMPL_HPP +#ifndef MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_ImPL_HPP +#define MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_ImPL_HPP #include "feature_extraction.hpp" @@ -14,90 +14,98 @@ namespace mlpack { namespace structured_tree { +/** + * Constructor: stores all the parameters in an object + * of feature_parameters class. + */ template StructuredForests:: -StructuredForests() +StructuredForests(FeatureParameters F) { // to do. + params = F; + //check if this works. 
+ std::cout << params.numImages << std::endl; } +/* template MatType StructuredForests:: -LoadData(MatType const &images, MatType const &boundaries,\ +LoadData(MatType const &Images, MatType const &boundaries,\ MatType const &segmentations) { - const size_t num_images = this->options["num_images"]; - const size_t row_size = this->options["row_size"]; - const size_t col_size = this->options["col_size"]; - MatType input_data(num_images * row_size * 5, col_size); + const size_t num_Images = this->params.num_Images; + const size_t rowSize = this->params.rowSize; + const size_t colSize = this->params.colSize; + MatType input_data(num_Images * rowSize * 5, colSize); // we store the input data as follows: - // images (3), boundaries (1), segmentations (1). - size_t loop_iter = num_images * 5; + // Images (3), boundaries (1), segmentations (1). + size_t loop_iter = num_Images * 5; size_t row_idx = 0; size_t col_i = 0, col_s = 0, col_b = 0; for(size_t i = 0; i < loop_iter; ++i) { if (i % 5 == 4) { - input_data.submat(row_idx, 0, row_idx + row_size - 1,\ - col_size - 1) = MatType(segmentations.colptr(col_s),\ - col_size, row_size).t(); + input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ + colSize - 1) = MatType(segmentations.colptr(col_s),\ + colSize, rowSize).t(); ++col_s; } else if (i % 5 == 3) { - input_data.submat(row_idx, 0, row_idx + row_size - 1,\ - col_size - 1) = MatType(boundaries.colptr(col_b),\ - col_size, row_size).t(); + input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ + colSize - 1) = MatType(boundaries.colptr(col_b),\ + colSize, rowSize).t(); ++col_b; } else { - input_data.submat(row_idx, 0, row_idx + row_size - 1,\ - col_size - 1) = MatType(images.colptr(col_i), - col_size, row_size).t(); + input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ + colSize - 1) = MatType(Images.colptr(col_i), + colSize, rowSize).t(); ++col_i; } - row_idx += row_size; + row_idx += rowSize; } return input_data; } +*/ +/** + * Get DImensions of Features + * @param FtrDIm 
Output vector that contains the result + */ template void StructuredForests:: -GetFeatureDimension(arma::vec &FtrDim) +GetFeatureDImension(arma::vec &FtrDIm) { - /* - shrink: amount to shrink channels - p_size: size of image patches - n_cell: number of self similarity cells - n_orient: number of orientations per gradient scale - */ - FtrDim = arma::vec(2); + FtrDIm = arma::vec(2); - const size_t shrink = this->options["shrink"]; - const size_t p_size = this->options["p_size"]; - const size_t n_cell = this->options["n_cell"]; - const size_t rgbd = this->options["rgbd"]; - const size_t n_orient = this->options["n_orient"]; - /* - n_color_ch: number of color channels - n_grad_ch: number of gradient channels - n_ch: total number of channels - */ - size_t n_color_ch; - if (this->options["rgbd"] == 0) - n_color_ch = 3; + const size_t shrink = this->params.shrink; + const size_t pSize = this->params.pSize; + const size_t numCell = this->params.numCell; + const size_t rgbd = this->params.rgbd; + const size_t numOrient = this->params.numOrient; + + size_t nColorCh; + if (this->params.rgbd == 0) + nColorCh = 3; else - n_color_ch = 4; + nColorCh = 4; - const size_t n_grad_ch = 2 * (1 + n_orient); - - const size_t n_ch = n_color_ch + n_grad_ch; - FtrDim[0] = std::pow((p_size / shrink) , 2) * n_ch; - FtrDim[1] = std::pow(n_cell , 2) * (std::pow (n_cell, 2) - 1) / 2 * n_ch; + const size_t nCh = nColorCh + 2 * (1 + numOrient); + FtrDIm[0] = std::pow((pSize / shrink) , 2) * nCh; + FtrDIm[1] = std::pow(numCell , 2) * (std::pow (numCell, 2) - 1) / 2 * nCh; } +/** + * Computes distance transform of 1D vector f. + * @param f input vector whose distance transform is to be found. + * @param n size of the Output vector to be made. + * @param inf a large double value. + * @param d Output vector which stores distance transform of f. 
+ */ template void StructuredForests:: DistanceTransform1D(const arma::vec& f, const size_t n, const double inf, @@ -134,42 +142,62 @@ DistanceTransform1D(const arma::vec& f, const size_t n, const double inf, return d; } +/** + * Computes distance transform of a 2D array + * @param Im input array whose distance transform is to be found. + * @param inf a large double value. + */ + template void StructuredForests:: -DistanceTransform2D(MatType &im, const double inf) +DistanceTransform2D(MatType &Im, const double inf) { - arma::vec f(std::max(im.n_rows, im.n_cols)); + arma::vec f(std::max(Im.n_rows, Im.n_cols)); // transform along columns - for (size_t x = 0; x < im.n_cols; ++x) + for (size_t x = 0; x < Im.n_cols; ++x) { - f.subvec(0, im.n_rows - 1) = im.col(x); + f.subvec(0, Im.n_rows - 1) = Im.col(x); arma::vec d; - this->DistanceTransform1D(f, im.n_rows, inf, d); - im.col(x) = d; + this->DistanceTransform1D(f, Im.n_rows, inf, d); + Im.col(x) = d; } // transform along rows - for (size_t y = 0; y < im.n_rows; y++) + for (size_t y = 0; y < Im.n_rows; y++) { - f.subvec(0, im.n_cols - 1) = im.row(y).t(); + f.subvec(0, Im.n_cols - 1) = Im.row(y).t(); arma::vec d; - this->DistanceTransform1D(f, im.n_cols, inf, d); - im.row(y) = d.t(); + this->DistanceTransform1D(f, Im.n_cols, inf, d); + Im.row(y) = d.t(); } } -/* euclidean distance transform of binary image using squared distance */ +/** + * euclidean distance transform of binary Image using squared distance + * @param Im Input binary Image whose distance transform is to be found. + * @param on if on == 1, 1 is taken as boundaries and vice versa. + * @param Out Output Image. + */ template void StructuredForests:: -DistanceTransformImage(const MatType& im, double on, MatType& out) +DistanceTransformImage(const MatType& Im, double on, MatType& Out) { //need a large value but not infinity. 
double inf = 999999.99; - MatType out = MatType(im.n_rows, im.n_cols, arma::fill::zeros); - out.elem( find(im != on) ).fill(inf); - this->DistanceTransform2D(out, inf); + MatType Out = MatType(Im.n_rows, Im.n_cols, arma::fill::zeros); + Out.elem( find(Im != on) ).fill(inf); + this->DistanceTransform2D(Out, inf); } +/** + * Makes a reflective border around an Image. + * @param InImage Image which we have to make border around. + * @param top border length at top. + * @param left border length at left. + * @param bottom border length at bottom. + * @param right border length at right. + * @param OutImage Output Image. + */ template void StructuredForests:: CopyMakeBorder(const CubeType& InImage, size_t top, @@ -211,6 +239,11 @@ CopyMakeBorder(const CubeType& InImage, size_t top, } } +/** + * Converts an Image in RGB color space to LUV color space. + * @param InImage Input Image in RGB color space. + * @param OutImage Ouptut Image in LUV color space. + */ template void StructuredForests:: RGB2LUV(const CubeType& InImage, CubeType OutImage) @@ -255,12 +288,6 @@ RGB2LUV(const CubeType& InImage, CubeType OutImage) xyz.slice(0)(i) = 0.430574 * r + 0.341550 * g + 0.178325 * b; xyz.slice(1)(i) = 0.222015 * r + 0.706655 * g + 0.071330 * b; xyz.slice(2)(i) = 0.020183 * r + 0.129553 * g + 0.939180 * b; - - /* - xyz.slice(0)(i) = 0.430574 * r + 0.341550 * g + 0.178325 * b; - xyz.slice(1)(i) = 0.222015 * r + 0.706655 * g + 0.129553 * b; - xyz.slice(2)(i) = 0.020183 * r + 0.071330 * g + 0.939180 * b; - */ } MatType nz(InImage.n_rows, InImage.n_cols); @@ -283,7 +310,14 @@ RGB2LUV(const CubeType& InImage, CubeType OutImage) - 13 * 0.468331) + 134 * maxi; } -/*implement this function in a column major order.*/ +/** + * Resizes the Image to the given size using Bilinear Interpolation + * @param src Input Image + * @param height Height of Output Image. + * @param width Width Out Output Image. 
+ * @param dst Output Image resized to (height, width) + */ +/*Implement this function in a column major order.*/ template void StructuredForests:: BilinearInterpolation(MatType const &src, @@ -315,9 +349,16 @@ BilinearInterpolation(MatType const &src, } } +/** + * Applies a separable linear filter to an Image + * @param InOutImage Input/Output Contains the input Image, The final filtered Image is + * stored in this param. + * @param kernel Input Kernel vector to be applied on Image. + * @param radius amount, the Image should be padded before applying filter. + */ template void StructuredForests:: -SepFilter2D(CubeType &InOutImage, arma::vec &kernel, size_t radius) +SepFilter2D(CubeType &InOutImage, const arma::vec& kernel, const size_t radius) { CubeType OutImage; this->CopyMakeBorder(InOutImage, radius, radius, radius, radius, OutImage); @@ -342,9 +383,14 @@ SepFilter2D(CubeType &InOutImage, arma::vec &kernel, size_t radius) } +/** + * Applies a triangle filter on an Image. + * @param InImage Input/Output Image on which filter is applied. + * @param radius Decides the size of kernel to be applied on Image. 
+ */ template void StructuredForests:: -ConvTriangle(CubeType &InImage, size_t radius) +ConvTriangle(CubeType &InImage, const size_t radius) { if (radius == 0) { @@ -390,7 +436,7 @@ MaxAndLoc(CubeType &mag, arma::umat &Location, CubeType& MaxVal) const for(size_t j = 0; j < mag.n_cols; ++j) { /*can use -infinity here*/ - double max = std::numeric_limits::min(); + double max = std::numeric_lImits::min(); for(size_t k = 0; k < mag.n_slices; ++k) { if(mag(i, j, k) > max) @@ -410,7 +456,7 @@ Gradient(const CubeType &InImage, MatType &Magnitude, MatType &Orientation) { - const size_t grd_norm_rad = this->options["grd_norm_rad"]; + const size_t grdNormRad = this->params.grdNormRad; CubeType dx(InImage.n_rows, InImage.n_cols, InImage.n_slices), dy(InImage.n_rows, InImage.n_cols, InImage.n_slices); @@ -420,7 +466,7 @@ Gradient(const CubeType &InImage, /* From MATLAB documentation: [FX,FY] = gradient(F), where F is a matrix, returns the - x and y components of the two-dimensional numerical gradient. + x and y components of the two-dImensional numerical gradient. FX corresponds to ∂F/∂x, the differences in x (horizontal) direction. FY corresponds to ∂F/∂y, the differences in the y (vertical) direction. */ @@ -468,13 +514,13 @@ Gradient(const CubeType &InImage, arma::umat Location(InImage.n_rows, InImage.n_cols); this->MaxAndLoc(mag, Location, Magnitude); - if(grd_norm_rad != 0) + if(grdNormRad != 0) { //we have to do this ugly thing, or override ConvTriangle // and sepFilter2D methods. CubeType mag2(InImage.n_rows, InImage.n_cols, 1); mag2.slice(0) = Magnitude; - this->ConvTriangle(mag2, grd_norm_rad); + this->ConvTriangle(mag2, grdNormRad); Magnitude = Magnitude / (mag2.slice(0) + 0.01); } MatType dx_mat(dx.n_rows, dx.n_cols),\ @@ -510,15 +556,15 @@ Histogram(const MatType& Magnitude, { //i don't think this function can be vectorized. 
- //n_orient: number of orientations per gradient scale - const size_t n_orient = this->options["n_orient"]; - //size of HistArr: n_rbin * n_cbin * n_orient . . . (create in caller...) + //numOrient: number of orientations per gradient scale + const size_t numOrient = this->params.numOrient; + //size of HistArr: n_rbin * n_cbin * numOrient . . . (create in caller...) const size_t n_rbin = (Magnitude.n_rows + downscale - 1) / downscale; const size_t n_cbin = (Magnitude.n_cols + downscale - 1) / downscale; double o_range, o; - o_range = arma::datum::pi / n_orient; + o_range = arma::datum::pi / numOrient; - HistArr = CubeType(n_rbin, n_cbin, n_orient); + HistArr = CubeType(n_rbin, n_cbin, numOrient); HistArr.zeros(); size_t r, c, o1, o2; @@ -532,14 +578,14 @@ Histogram(const MatType& Magnitude, if( interp != 0) { o = Orientation(i, j) / o_range; - o1 = ((size_t) o) % n_orient; - o2 = (o1 + 1) % n_orient; + o1 = ((size_t) o) % numOrient; + o2 = (o1 + 1) % numOrient; HistArr(r, c, o1) += Magnitude(i, j) * (1 + (int)o - o); HistArr(r, c, o2) += Magnitude(i, j) * (o - (int) o); } else { - o1 = (size_t) (Orientation(i, j) / o_range + 0.5) % n_orient; + o1 = (size_t) (Orientation(i, j) / o_range + 0.5) % numOrient; HistArr(r, c, o1) += Magnitude(i, j); } } @@ -558,10 +604,10 @@ GetShrunkChannels(const CubeType& InImage, CubeType ®_ch, CubeType &ss_ch) CubeType luv; this->RGB2LUV(InImage, luv); - const size_t shrink = this->options["shrink"]; - const size_t n_orient = this->options["n_orient"]; - const size_t grd_smooth_rad = this->options["grd_smooth_rad"]; - const size_t grd_norm_rad = this->options["grd_norm_rad"]; + const size_t shrink = this->params.shrink; + const size_t numOrient = this->params.numOrient; + const size_t grdSmoothRad = this->params.grdSmoothRad; + const size_t grdNormRad = this->params.grdNormRad; const size_t num_channels = 13; const size_t rsize = luv.n_rows / shrink; const size_t csize = luv.n_cols / shrink; @@ -578,7 +624,7 @@ 
GetShrunkChannels(const CubeType& InImage, CubeType ®_ch, CubeType &ss_ch) while(scale <= 1.0) { - CubeType img( (luv.n_rows * scale), + CubeType Img( (luv.n_rows * scale), (luv.n_cols * scale), luv.n_slices ); @@ -586,10 +632,10 @@ GetShrunkChannels(const CubeType& InImage, CubeType ®_ch, CubeType &ss_ch) { this->BilinearInterpolation(luv.slice(slice_idx), (luv.n_rows * scale), (luv.n_cols * scale) - img.slice(slice_idx)); + Img.slice(slice_idx)); } - CubeType OutImage = this->ConvTriangle(img, grd_smooth_rad); + CubeType OutImage = this->ConvTriangle(Img, grdSmoothRad); MatType Magnitude(InImage.n_rows, InImage.n_cols), Orientation(InImage.n_rows, InImage.n_cols); @@ -611,22 +657,22 @@ GetShrunkChannels(const CubeType& InImage, CubeType ®_ch, CubeType &ss_ch) } //cout << "size of channels: " << arma::size(channels) << endl; - double reg_smooth_rad, ss_smooth_rad; - reg_smooth_rad = this->options["reg_smooth_rad"] / (double) shrink; - ss_smooth_rad = this->options["ss_smooth_rad"] / (double) shrink; + double regSmoothRad, ssSmoothRad; + regSmoothRad = this->params.regSmoothRad / (double) shrink; + ssSmoothRad = this->params.ssSmoothRad / (double) shrink; - if (reg_smooth_rad > 1.0) - reg_ch = this->ConvTriangle(channels, (size_t) (std::round(reg_smooth_rad)) ); + if (regSmoothRad > 1.0) + reg_ch = this->ConvTriangle(channels, (size_t) (std::round(regSmoothRad)) ); else - reg_ch = this->ConvTriangle(channels, reg_smooth_rad); + reg_ch = this->ConvTriangle(channels, regSmoothRad); - if (ss_smooth_rad > 1.0) - ss_ch = this->ConvTriangle(channels, (size_t) (std::round(ss_smooth_rad)) ); + if (ssSmoothRad > 1.0) + ss_ch = this->ConvTriangle(channels, (size_t) (std::round(ssSmoothRad)) ); else - ss_ch = this->ConvTriangle(channels, ss_smooth_rad); + ss_ch = this->ConvTriangle(channels, ssSmoothRad); } @@ -635,12 +681,12 @@ void StructuredForests:: ViewAsWindows(const CubeType& channels, arma::umat const &loc, CubeType& features) { - // 500 for pos_loc, and 500 for 
neg_loc. + // 500 for posLoc, and 500 for negLoc. // channels = 160, 240, 13. features = CubeType(16, 16, 1000 * 13); const size_t patchSize = 16; const size_t p = patchSize / 2; - //increase the channel boundary to protect error against image boundaries. + //increase the channel boundary to protect error against Image boundaries. CubeType inc_ch; this->CopyMakeBorder(channels, p, p, p, p, inc_ch); for (size_t i = 0, channel = 0; i < loc.n_rows; ++i) @@ -678,14 +724,14 @@ Rearrange(CubeType const &channels, CubeType& ch) } } -// returns 256 * 1000 * 13 dimension features. +// returns 256 * 1000 * 13 dImension features. template void StructuredForests:: GetRegFtr(const CubeType& channels,const arma::umat& loc CubeType& RegFtr) { - int shrink = this->options["shrink"]; - int p_size = this->options["p_size"] / shrink; + int shrink = this->params.shrink; + int pSize = this->params.pSize / shrink; CubeType wind; this->ViewAsWindows(channels, loc, wind); this->Rearrange(wind, RegFtr); @@ -694,15 +740,15 @@ GetRegFtr(const CubeType& channels,const arma::umat& loc template CubeType StructuredForests:: PDist(const CubeType& features, const arma::uvec& grid_pos, - CubeType& output) + CubeType& Output) { // size of DestArr: // InImage.n_rows * (InImage.n_rows - 1)/2 * InImage.n_slices //find nC2 differences, for locations in the grid_pos. - //python: input: (716, 256, 13) --->(716, 25, 13) ; output: (716, 300, 13). - //input features : 256,1000,13; output: 300, 1000, 13 + //python: input: (716, 256, 13) --->(716, 25, 13) ; Output: (716, 300, 13). 
+ //input features : 256,1000,13; Output: 300, 1000, 13 - output = CubeType(300, 1000, 13); + Output = CubeType(300, 1000, 13); for(size_t k = 0; k < features.n_slices; ++k) { size_t r_idx = 0; @@ -710,7 +756,7 @@ PDist(const CubeType& features, const arma::uvec& grid_pos, { for(size_t j = i + 1; j < grid_pos.n_elem; ++j) //loop length : 25 { - output.slice(k).row(r_idx) = features.slice(k).row(grid_pos(i)) + Output.slice(k).row(r_idx) = features.slice(k).row(grid_pos(i)) - features.slice(k).row(grid_pos(j)); ++r_idx; } @@ -718,32 +764,32 @@ PDist(const CubeType& features, const arma::uvec& grid_pos, } } -//returns 300,1000,13 dimension features. +//returns 300,1000,13 dImension features. template void StructuredForests:: GetSSFtr(CubeType const &channels, arma::umat const &loc CubeType SSFtr) { - const size_t shrink = this->options["shrink"]; - const size_t p_size = this->options["p_size"] / shrink; + const size_t shrink = this->params.shrink; + const size_t pSize = this->params.pSize / shrink; - //n_cell: number of self similarity cells - const size_t n_cell = this->options["n_cell"]; - const size_t half_cell_size = (size_t) round(p_size / (2.0 * n_cell)); + //numCell: number of self sImilarity cells + const size_t numCell = this->params.numCell; + const size_t half_cell_size = (size_t) round(pSize / (2.0 * numCell)); - arma::uvec g_pos(n_cell); - for(size_t i = 0; i < n_cell; ++i) + arma::uvec g_pos(numCell); + for(size_t i = 0; i < numCell; ++i) { - g_pos(i) = (size_t)round( (i + 1) * (p_size + 2 * half_cell_size \ - - 1) / (n_cell + 1.0) - half_cell_size); + g_pos(i) = (size_t)round( (i + 1) * (pSize + 2 * half_cell_size \ + - 1) / (numCell + 1.0) - half_cell_size); } - arma::uvec grid_pos(n_cell * n_cell); + arma::uvec grid_pos(numCell * numCell); size_t k = 0; - for(size_t i = 0; i < n_cell; ++i) + for(size_t i = 0; i < numCell; ++i) { - for(size_t j = 0; j < n_cell; ++j) + for(size_t j = 0; j < numCell; ++j) { - grid_pos(k) = g_pos(i) * p_size + g_pos(j); + 
grid_pos(k) = g_pos(i) * pSize + g_pos(j); ++k; } } @@ -757,33 +803,33 @@ GetSSFtr(CubeType const &channels, arma::umat const &loc template void StructuredForests:: -GetFeatures(const MatType &image, arma::umat &loc, +GetFeatures(const MatType &Image, arma::umat &loc, CubeType& RegFtr, CubeType& SSFtr) { - const size_t row_size = this->options["row_size"]; - const size_t col_size = this->options["col_size"]; - const size_t bottom = (4 - (image.n_rows / 3) % 4) % 4; - const size_t right = (4 - image.n_cols % 4) % 4; + const size_t rowSize = this->params.rowSize; + const size_t colSize = this->params.colSize; + const size_t bottom = (4 - (Image.n_rows / 3) % 4) % 4; + const size_t right = (4 - Image.n_cols % 4) % 4; //cout << "Botttom = " << bottom << " right = " << right << endl; - CubeType InImage(image.n_rows / 3, image.n_cols, 3); + CubeType InImage(Image.n_rows / 3, Image.n_cols, 3); for(size_t i = 0; i < 3; ++i) { - InImage.slice(i) = image.submat(i * row_size, 0, \ - (i + 1) * row_size - 1, col_size - 1); + InImage.slice(i) = Image.submat(i * rowSize, 0, \ + (i + 1) * rowSize - 1, colSize - 1); } CubeType OutImage; this->CopyMakeBorder(InImage, 0, 0, bottom, right, OutImage); const size_t num_channels = 13; - const size_t shrink = this->options["shrink"]; + const size_t shrink = this->params.shrink; const size_t rsize = OutImage.n_rows / shrink; const size_t csize = OutImage.n_cols / shrink; - /* this part gives double free or corruption out error - when executed for a second time */ + /* this part gives double free or corruption Out error + when executed for a second tIme */ CubeType reg_ch = CubeType(rsize, csize, num_channels); CubeType ss_ch = CubeType(rsize, csize, num_channels); this->GetShrunkChannels(InImage, reg_ch, ss_ch); @@ -794,99 +840,92 @@ GetFeatures(const MatType &image, arma::umat &loc, this->GetSSFtr(ss_ch, loc, SSFtr); } +/** + * This functions prepares the data, + * and extracts features, structured labels. 
+ * @param: + */ + template void StructuredForests:: -PrepareData(MatType const &InputData) +PrepareData(const MatType& Images, const MatType& Boundaries,\ + const MatType& Segmentations) { - const size_t num_images = this->options["num_images"]; - const size_t n_tree = this->options["n_tree"]; - const size_t n_pos = this->options["n_pos"]; - const size_t n_neg = this->options["n_neg"]; - const double fraction = 0.25; - const size_t p_size = this->options["p_size"]; - const size_t g_size = this->options["g_size"]; - const size_t shrink = this->options["shrink"]; - const size_t row_size = this->options["row_size"]; - const size_t col_size = this->options["col_size"]; - // p_rad = radius of image patches. - // g_rad = radius of ground truth patches. - const size_t p_rad = p_size / 2, g_rad = g_size / 2; + const size_t numImages = this->params.numImages; + const size_t numTree = this->params.numTree; + const size_t numPos = this->params.numPos; + const size_t numNeg = this->params.numNeg; + const double fraction = this->params.fraction; + const size_t pSize = this->params.pSize; + const size_t gSize = this->params.gSize; + const size_t shrink = this->params.shrink; + const size_t rowSize = this->params.rowSize; + const size_t colSize = this->params.colSize; + // pRad = radius of Image patches. + // gRad = radius of ground truth patches. + const size_t pRad = pSize / 2, gRad = gSize / 2; - arma::vec FtrDim; - this->GetFeatureDimension(FtrDim); - const size_t n_ftr_dim = FtrDim(0) + FtrDim(1); - const size_t n_smp_ftr_dim = (size_t)(n_ftr_dim * fraction); + arma::vec FtrDIm; + this->GetFeatureDImension(FtrDIm); + const size_t nFtrDIm = FtrDIm(0) + FtrDIm(1); + const size_t nSmpFtrDIm = (size_t)(nFtrDIm * fraction); - for(size_t i = 0; i < n_tree; ++i) + for(size_t i = 0; i < numTree; ++i) { - //implement the logic for if data already exists. - MatType ftrs = arma::zeros(n_pos + n_neg, n_smp_ftr_dim); + //Implement the logic for if data already exists. 
+ MatType ftrs = arma::zeros(numPos + numNeg, nSmpFtrDIm); //effectively a 3d array. . . - MatType lbls = arma::zeros( g_size * g_size, (n_pos + n_neg )); + MatType lbls = arma::zeros( gSize * gSize, (numPos + numNeg )); // still to be done: store features and labels calculated // in the loop and store it in these Matrices. // Could use some suggestions for this. - size_t loop_iter = num_images * 5; - for(size_t j = 0; j < loop_iter; j += 5) + size_t loop_iter = num_Images; + for(size_t j = 0; j < loop_iter; ++j) { - MatType img, bnds, segs; - img = InputData.submat(j * row_size, 0, (j + 3) * row_size - 1, col_size - 1); - bnds = InputData.submat( (j + 3) * row_size, 0, \ - (j + 4) * row_size - 1, col_size - 1 ); - segs = InputData.submat( (j + 4) * row_size, 0, \ - (j + 5) * row_size - 1, col_size - 1 ); - - MatType mask = arma::zeros(row_size, col_size); - for(size_t b = 0; b < mask.n_cols; b = b + shrink) - for(size_t a = 0; a < mask.n_rows; a = a + shrink) - mask(a, b) = 1; - mask.col(p_rad - 1).fill(0); - mask.row( (mask.n_rows - 1) - (p_rad - 1) ).fill(0); - mask.submat(0, 0, mask.n_rows - 1, p_rad - 1).fill(0); - mask.submat(0, mask.n_cols - p_rad, mask.n_rows - 1, + MatType Img, bnds, segs; + Img = Images.submat(j * rowSize, 0, (j + 3) * rowSize - 1, colSize - 1); + bnds = Boundaries.submat( j * rowSize, 0, \ + j * rowSize - 1, colSize - 1 ); + segs = Segmentations.submat( j * rowSize, 0, \ + j * rowSize - 1, colSize - 1 ); + + MatType mask(rowSize, colSize, arma::fill::ones); + mask.col(pRad - 1).fill(0); + mask.row( (mask.n_rows - 1) - (pRad - 1) ).fill(0); + mask.submat(0, 0, mask.n_rows - 1, pRad - 1).fill(0); + mask.submat(0, mask.n_cols - pRad, mask.n_rows - 1, mask.n_cols - 1).fill(0); // number of positive or negative patches per ground truth. 
- //int n_patches_per_gt = (int) (ceil( (float)n_pos / num_images )); - const size_t n_patches_per_gt = 500; - //cout << "n_patches_per_gt = " << n_patches_per_gt << endl; + + const size_t nPatchesPerGt = 500; MatType dis; this->DistanceTransformImage(bnds, 1, dis) - dis = arma::sqrt(dis); - //dis.transform( [](double val, const int& g_rad) { return (double)(val < g_rad); } ); - //dis2.transform( [](double val, const int& g_rad) { return (double)(val >= g_rad); } ); - //dis.elem( arma::find(dis >= g_rad) ).zeros(); - //dis2.elem( arma::find(dis < g_rad) ).zeros(); - + dis = arma::sqrt(dis); - arma::uvec pos_loc = arma::find( (dis < g_rad) % mask ); - arma::uvec neg_loc = arma::find( (dis >= g_rad) % mask ); + arma::uvec posLoc = arma::find( (dis < gRad) % mask ); + arma::uvec negLoc = arma::find( (dis >= gRad) % mask ); - pos_loc = arma::shuffle(pos_loc); - neg_loc = arma::shuffle(neg_loc); + posLoc = arma::shuffle(posLoc); + negLoc = arma::shuffle(negLoc); - arma::umat loc(n_patches_per_gt * 2, 2); - //cout << "pos_loc size: " << arma::size(pos_loc) << " neg_loc size: " << arma::size(neg_loc) << endl; - //cout << "n_patches_per_gt = " << n_patches_per_gt << endl; - for(size_t i = 0; i < n_patches_per_gt; ++i) + arma::umat loc(nPatchesPerGt * 2, 2); + + for(size_t i = 0; i < nPatchesPerGt; ++i) { - loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), pos_loc(i) ).t(); - //cout << "pos_loc: " << loc(i, 0) << ", " << loc(i, 1) << endl; + loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), posLoc(i) ).t(); + //cout << "posLoc: " << loc(i, 0) << ", " << loc(i, 1) << endl; } - for(size_t i = n_patches_per_gt; i < 2 * n_patches_per_gt; ++i) + for(size_t i = nPatchesPerGt; i < 2 * nPatchesPerGt; ++i) { - loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), neg_loc(i) ).t(); - //cout << "neg_loc: " << loc(i, 0) << ", " << loc(i, 1) << endl; + loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), negLoc(i - nPatchesPerGt) ).t(); } - 
// cout << "num patches = " << n_patches_per_gt << " num elements + = " << pos_loc.n_elem\ - // << " num elements - = " << neg_loc.n_elem << " dis.size " << dis.n_elem << endl; - CubeType SSFtr, RegFtr; - arma::field F = this->GetFeatures(img, loc); + this->GetFeatures(Img, loc, RegFtr, SSFtr); //randomly sample 70 values each from reg_ftr and ss_ftr. /* CubeType ftr(140, 1000, 13); @@ -895,22 +934,22 @@ PrepareData(MatType const &InputData) arma::uvec rs = r.shuffle(); arma::uvec ss = s.shuffle(); */ - MatType lbl(g_size * g_size, 1000); + MatType lbl(gSize * gSize, 1000); CubeType s(segs.n_rows, segs.n_cols, 1); // have to do this or we can overload the CopyMakeBorder to support MatType. s.slice(0) = segs; CubeType in_segs; - this->CopyMakeBorder(s, g_rad, - g_rad, g_rad, g_rad, in_segs); + this->CopyMakeBorder(s, gRad, + gRad, gRad, gRad, in_segs); for(size_t i = 0; i < loc.n_rows; ++i) { size_t x = loc(i, 0); size_t y = loc(i, 1); //cout << "x, y = " << x << " " << y << endl; lbl.col(i) = arma::vectorise(in_segs.slice(0)\ - .submat((x + g_rad) - g_rad, (y + g_rad) - g_rad,\ - (x + g_rad) + g_rad - 1, (y + g_rad) + g_rad - 1)); + .submat((x + gRad) - gRad, (y + gRad) - gRad,\ + (x + gRad) + gRad - 1, (y + gRad) + gRad - 1)); } } } @@ -939,7 +978,7 @@ Discretize(MatType const &labels, size_t n_class, size_t n_sample) for (size_t j = 0; j < zs.n_rows; ++i) zs(i, j) = (labels(i, z1(j)) == labels(i, z2(j))) ? 1 : 0; } - zs -= arma::mean(zs, 1); // calculate mean about cols. n_col = 256. + zs -= arma::mean(zs, 1); // calculate mean abOut cols. n_col = 256. 
if ( arma::find(zs).n_elem == 0 ) { labels.fill(ones); @@ -948,7 +987,7 @@ Discretize(MatType const &labels, size_t n_class, size_t n_sample) { //find most representative segs } - // discretize zs by discretizing pca dimensions + // discretize zs by discretizing pca dImensions size_t d = min(5, n_sample, (size_t)floor(log(n_class, 2))); zs = pca(); diff --git a/src/mlpack/methods/edge_boxes/feature_parameters.hpp b/src/mlpack/methods/edge_boxes/feature_parameters.hpp index 567fc6c1479..3cf1f7a26a7 100644 --- a/src/mlpack/methods/edge_boxes/feature_parameters.hpp +++ b/src/mlpack/methods/edge_boxes/feature_parameters.hpp @@ -65,6 +65,9 @@ class FeatureParameters void Fraction(double value) { fraction = value; } double Fraction() const { return fraction; } + + void NumTree(double value) { numTree = value; } + double NumTree() const { return numTree; } private: size_t numImages; @@ -77,12 +80,14 @@ class FeatureParameters size_t grdNormRad; size_t regSmoothRad; size_t ssSmoothRad; + double fraction; size_t pSize; size_t gSize; size_t numCell; size_t numPos; size_t numNeg; - double numCell; + size_t numCell; + size_t numTree; }; } From b506eea1d4f0e2d527c17349b35edcea9652c79a Mon Sep 17 00:00:00 2001 From: nilayjain Date: Mon, 20 Jun 2016 04:10:42 +0000 Subject: [PATCH 04/14] added discretize function --- .../methods/edge_boxes/edge_boxes_main.cpp | 60 ++- .../methods/edge_boxes/feature_extraction.hpp | 72 +-- .../edge_boxes/feature_extraction_impl.hpp | 409 ++++++++++-------- .../methods/edge_boxes/feature_parameters.hpp | 5 +- src/mlpack/tests/edge_boxes_test.cpp | 87 ++-- src/mlpack/tests/pca_test.cpp | 20 +- 6 files changed, 359 insertions(+), 294 deletions(-) diff --git a/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp b/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp index b8bb0a79a97..4afa4cbc958 100644 --- a/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp +++ b/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp @@ -44,48 +44,38 @@ int main() nms: if true 
apply non-maximum suppression to edges */ - map options; - options["num_images"] = 2; - options["row_size"] = 321; - options["col_size"] = 481; - options["rgbd"] = 0; - options["shrink"] = 2; - options["n_orient"] = 4; - options["grd_smooth_rad"] = 0; - options["grd_norm_rad"] = 4; - options["reg_smooth_rad"] = 2; - options["ss_smooth_rad"] = 8; - options["p_size"] = 32; - options["g_size"] = 16; - options["n_cell"] = 5; + FeatureParameters params = FeatureParameters(); - options["n_pos"] = 10000; - options["n_neg"] = 10000; - //options["fraction"] = 0.25; - options["n_tree"] = 8; - options["n_class"] = 2; - options["min_count"] = 1; - options["min_child"] = 8; - options["max_depth"] = 64; - options["split"] = 0; // we use 0 for gini, 1 for entropy, 2 for other - options["stride"] = 2; - options["sharpen"] = 2; - options["n_tree_eval"] = 4; - options["nms"] = 1; // 1 for true, 0 for false - - StructuredForests SF(options); + params.NumImages(2); + params.RowSize(321); + params.ColSize(481); + params.RGBD(0); + params.Shrink(2); + params.NumOrient(4); + params.GrdSmoothRad(0); + params.GrdNormRad(4); + params.RegSmoothRad(2); + params.SSSmoothRad(8); + params.Fraction(0.25); + params.PSize(32); + params.GSize(16); + params.NumCell(5); + params.NumPos(10000); + params.NumNeg(10000); + params.NumCell(5); + params.NumTree(8); + StructuredForests SF(params); // arma::uvec x(2); //SF.GetFeatureDimension(x); arma::mat segmentations, boundaries, images; - data::Load("/home/nilay/Desktop/GSoC/code/example/example/small_images.csv", images); - data::Load("/home/nilay/Desktop/GSoC/code/example/example/small_boundary_1.csv", boundaries); - data::Load("/home/nilay/Desktop/GSoC/code/example/example/small_segmentation_1.csv", segmentations); + data::Load("/home/nilay/example/small_images.csv", images); + data::Load("/home/nilay/example/small_boundary_1.csv", boundaries); + data::Load("/home/nilay/example/small_segmentation_1.csv", segmentations); - arma::mat input_data = 
SF.LoadData(images, boundaries, segmentations); - cout << input_data.n_rows << " " << input_data.n_cols << endl; - SF.PrepareData(input_data); + SF.PrepareData(images, boundaries, segmentations); cout << "PrepareData done." << endl; return 0; } + diff --git a/src/mlpack/methods/edge_boxes/feature_extraction.hpp b/src/mlpack/methods/edge_boxes/feature_extraction.hpp index ed371f73c97..16fa8c1ad92 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction.hpp @@ -9,7 +9,7 @@ //#define INF 999999.9999 //#define EPS 1E-20 #include - +#include "feature_parameters.hpp" namespace mlpack { namespace structured_tree { @@ -22,16 +22,15 @@ class StructuredForests public: static constexpr double eps = 1e-20; - - std::map options; - - void StructuredForests(FeatureParameters F); + + StructuredForests(FeatureParameters F); /* MatType LoadData(MatType const &images, MatType const &boundaries,\ MatType const &segmentations);*/ - void PrepareData(MatType const &InputData); + void PrepareData(const MatType& Images, const MatType& Boundaries,\ + const MatType& Segmentations); - void GetFeatureDimension(arma::vec &FtrDim); + void GetFeatureDimension(arma::vec& FtrDim); void DistanceTransform1D(const arma::vec& f, const size_t n,\ const double inf, arma::vec& d); @@ -40,45 +39,55 @@ class StructuredForests void DistanceTransformImage(const MatType& Im, double on, MatType& Out); - void GetFeatures(const MatType &Image, arma::umat &loc, - CubeType& RegFtr, CubeType& SSFtr); + void GetFeatures(const MatType &Image, arma::umat &loc,\ + CubeType& RegFtr, CubeType& SSFtr,\ + const arma::vec& table); - CubeType CopyMakeBorder(CubeType const &InImage, size_t top, - size_t left, size_t bottom, size_t right); + void CopyMakeBorder(const CubeType& InImage, size_t top, + size_t left, size_t bottom, size_t right, + CubeType& OutImage); - void GetShrunkChannels(CubeType const &InImage, CubeType ®_ch, CubeType &ss_ch); + void 
GetShrunkChannels(const CubeType& InImage, CubeType ®_ch,\ + CubeType &ss_ch, const arma::vec& table); - CubeType RGB2LUV(CubeType const &InImage); + void RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ + const arma::vec& table); - MatType bilinearInterpolation(MatType const &src, - size_t height, size_t width); + void BilinearInterpolation(const MatType& src, + size_t height, size_t width, + MatType& dst); - CubeType sepFilter2D(CubeType &InOutImage, arma::vec &kernel,\ - size_t radius); + void SepFilter2D(CubeType &InOutImage, const arma::vec& kernel, const size_t radius); - CubeType ConvTriangle(CubeType &InImage, size_t radius); + void ConvTriangle(CubeType &InImage, const size_t radius); - void Gradient(CubeType const &InImage, - MatType &Magnitude, - MatType &Orientation); + void Gradient(const CubeType& InImage, + MatType& Magnitude, + MatType& Orientation); - MatType MaxAndLoc(CubeType &mag, arma::umat &Location) const; + void MaxAndLoc(CubeType &mag, arma::umat &Location, MatType& MaxVal) const; - CubeType Histogram(MatType const &Magnitude, - MatType const &Orientation, - size_t downscale, size_t interp); + void Histogram(const MatType& Magnitude, + const MatType& Orientation, + size_t downscale, size_t interp, + CubeType& HistArr); - CubeType ViewAsWindows(CubeType const &channels, arma::umat const &loc); + void ViewAsWindows(const CubeType& channels, const arma::umat& loc, + CubeType& features); - CubeType GetRegFtr(CubeType const &channels, arma::umat const &loc); + void GetRegFtr(const CubeType& channels, const arma::umat& loc, + CubeType& RegFtr); - CubeType GetSSFtr(CubeType const &channels, arma::umat const &loc); + void GetSSFtr(const CubeType& channels, const arma::umat& loc, + CubeType SSFtr); - CubeType Rearrange(CubeType const &channels); + void Rearrange(const CubeType& channels, CubeType& ch); - CubeType PDist(CubeType const &features, arma::uvec const &grid_pos); + void PDist(const CubeType& features, const arma::uvec& grid_pos, + 
CubeType& Output); - //void Discretize(MatType const &lbl, size_t n_class, size_t n_sample); + size_t Discretize(const MatType& labels, const size_t nClass,\ + const size_t nSample, arma::vec& DiscreteLabels); }; @@ -87,3 +96,4 @@ class StructuredForests #include "feature_extraction_impl.hpp" #endif + diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp index 5d6abb98a5c..c175906081e 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ -9,8 +9,7 @@ #include "feature_extraction.hpp" -#include - +#include namespace mlpack { namespace structured_tree { @@ -22,10 +21,7 @@ template StructuredForests:: StructuredForests(FeatureParameters F) { - // to do. params = F; - //check if this works. - std::cout << params.numImages << std::endl; } /* @@ -35,8 +31,8 @@ LoadData(MatType const &Images, MatType const &boundaries,\ MatType const &segmentations) { const size_t num_Images = this->params.num_Images; - const size_t rowSize = this->params.rowSize; - const size_t colSize = this->params.colSize; + const size_t rowSize = this->params.RowSize(); + const size_t colSize = this->params.ColSize(); MatType input_data(num_Images * rowSize * 5, colSize); // we store the input data as follows: // Images (3), boundaries (1), segmentations (1). 
@@ -70,33 +66,33 @@ LoadData(MatType const &Images, MatType const &boundaries,\ } return input_data; } + */ /** * Get DImensions of Features - * @param FtrDIm Output vector that contains the result + * @param FtrDim Output vector that contains the result */ template void StructuredForests:: -GetFeatureDImension(arma::vec &FtrDIm) +GetFeatureDimension(arma::vec& FtrDim) { - FtrDIm = arma::vec(2); + FtrDim = arma::vec(2); - const size_t shrink = this->params.shrink; - const size_t pSize = this->params.pSize; - const size_t numCell = this->params.numCell; - const size_t rgbd = this->params.rgbd; - const size_t numOrient = this->params.numOrient; + const size_t shrink = this->params.Shrink(); + const size_t pSize = this->params.PSize(); + const size_t numCell = this->params.NumCell(); + const size_t numOrient = this->params.NumOrient(); size_t nColorCh; - if (this->params.rgbd == 0) + if (this->params.RGBD() == 0) nColorCh = 3; else nColorCh = 4; const size_t nCh = nColorCh + 2 * (1 + numOrient); - FtrDIm[0] = std::pow((pSize / shrink) , 2) * nCh; - FtrDIm[1] = std::pow(numCell , 2) * (std::pow (numCell, 2) - 1) / 2 * nCh; + FtrDim[0] = std::pow((pSize / shrink) , 2) * nCh; + FtrDim[1] = std::pow(numCell , 2) * (std::pow (numCell, 2) - 1) / 2 * nCh; } /** @@ -139,7 +135,6 @@ DistanceTransform1D(const arma::vec& f, const size_t n, const double inf, ++k; d[q] = (q - v[k]) * (q - v[k]) + f[v[k]]; } - return d; } /** @@ -184,27 +179,27 @@ DistanceTransformImage(const MatType& Im, double on, MatType& Out) { //need a large value but not infinity. double inf = 999999.99; - MatType Out = MatType(Im.n_rows, Im.n_cols, arma::fill::zeros); + Out = MatType(Im.n_rows, Im.n_cols, arma::fill::zeros); Out.elem( find(Im != on) ).fill(inf); this->DistanceTransform2D(Out, inf); } /** * Makes a reflective border around an Image. - * @param InImage Image which we have to make border around. - * @param top border length at top. - * @param left border length at left. 
- * @param bottom border length at bottom. - * @param right border length at right. - * @param OutImage Output Image. + * @param InImage, Image which we have to make border around. + * @param top, border length (to be incremented) at top. + * @param left, border length at left. + * @param bottom, border length at bottom. + * @param right, border length at right. + * @param OutImage, Output Image. */ template void StructuredForests:: CopyMakeBorder(const CubeType& InImage, size_t top, - size_t left, size_t bottom, size_t right + size_t left, size_t bottom, size_t right, CubeType& OutImage) { - OutImage = MatType(InImage.n_rows + top + bottom, InImage.n_cols + left + right, InImage.n_slices); + OutImage = CubeType(InImage.n_rows + top + bottom, InImage.n_cols + left + right, InImage.n_slices); for(size_t i = 0; i < InImage.n_slices; ++i) { @@ -241,35 +236,17 @@ CopyMakeBorder(const CubeType& InImage, size_t top, /** * Converts an Image in RGB color space to LUV color space. + * RGB must range in (0.0, 1.0). * @param InImage Input Image in RGB color space. * @param OutImage Ouptut Image in LUV color space. */ template void StructuredForests:: -RGB2LUV(const CubeType& InImage, CubeType OutImage) +RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ + const arma::vec& table) { //assert type is double or float. 
- double a, y0, maxi; - a = std::pow(29.0, 3) / 27.0; - y0 = 8.0 / a; - maxi = 1.0 / 270.0; - arma::vec table(1064); - for (size_t i = 0; i <= 1024; ++i) - { - table(i) = i / 1024.0; - - if (table(i) > y0) - table(i) = 116 * pow(table(i), 1.0/3.0) - 16.0; - else - table(i) = table(i) * a; - - table(i) = table(i) * maxi; - } - for(size_t i = 1025; i < table.n_elem; ++i) - { - table(i) = table(i - 1); - } MatType rgb2xyz; rgb2xyz << 0.430574 << 0.222015 << 0.020183 << arma::endr @@ -289,13 +266,11 @@ RGB2LUV(const CubeType& InImage, CubeType OutImage) xyz.slice(1)(i) = 0.222015 * r + 0.706655 * g + 0.071330 * b; xyz.slice(2)(i) = 0.020183 * r + 0.129553 * g + 0.939180 * b; } - MatType nz(InImage.n_rows, InImage.n_cols); nz = 1.0 / ( xyz.slice(0) + (15 * xyz.slice(1) ) + (3 * xyz.slice(2) + 1e-35)); - OutImage = MatType(InImage.n_rows, InImage.n_cols, InImage.n_slices); - + OutImage = CubeType(InImage.n_rows, InImage.n_cols, InImage.n_slices); for(size_t j = 0; j < xyz.n_cols; ++j) { for(size_t i = 0; i < xyz.n_rows; ++i) @@ -303,7 +278,7 @@ RGB2LUV(const CubeType& InImage, CubeType OutImage) OutImage(i, j, 0) = table( static_cast( (1024 * xyz(i, j, 1) ) ) ); } } - + double maxi = 1.0 / 270.0; OutImage.slice(1) = OutImage.slice(0) % (13 * 4 * (xyz.slice(0) % nz) \ - 13 * 0.197833) + 88 * maxi; OutImage.slice(2) = OutImage.slice(0) % (13 * 9 * (xyz.slice(1) % nz) \ @@ -320,9 +295,9 @@ RGB2LUV(const CubeType& InImage, CubeType OutImage) /*Implement this function in a column major order.*/ template void StructuredForests:: -BilinearInterpolation(MatType const &src, +BilinearInterpolation(const MatType& src, size_t height, size_t width, - MatType dst) + MatType& dst) { dst = MatType(height, width); double const x_ratio = static_cast((src.n_cols - 1)) / width; @@ -399,10 +374,10 @@ ConvTriangle(CubeType &InImage, const size_t radius) else if (radius <= 1) { const double p = 12.0 / radius / (radius + 2) - 2; - arma::vec kernel = {1 , p, 1}; + arma::vec kernel = {1, p, 1}; 
kernel /= (p + 2); - this->sepFilter2D(InImage, kernel, radius); + this->SepFilter2D(InImage, kernel, radius); } else { @@ -418,7 +393,7 @@ ConvTriangle(CubeType &InImage, const size_t radius) kernel(i) = r--; kernel /= std::pow(radius + 1, 2); - this->sepFilter2D(InImage, kernel, radius); + this->SepFilter2D(InImage, kernel, radius); } } @@ -427,7 +402,7 @@ ConvTriangle(CubeType &InImage, const size_t radius) // also stores the locations of max values in Location template void StructuredForests:: -MaxAndLoc(CubeType &mag, arma::umat &Location, CubeType& MaxVal) const +MaxAndLoc(CubeType& mag, arma::umat& Location, MatType& MaxVal) const { /*Vectorize this function after prototype works*/ MaxVal = MatType(Location.n_rows, Location.n_cols); @@ -436,7 +411,7 @@ MaxAndLoc(CubeType &mag, arma::umat &Location, CubeType& MaxVal) const for(size_t j = 0; j < mag.n_cols; ++j) { /*can use -infinity here*/ - double max = std::numeric_lImits::min(); + double max = std::numeric_limits::min(); for(size_t k = 0; k < mag.n_slices; ++k) { if(mag(i, j, k) > max) @@ -450,13 +425,16 @@ MaxAndLoc(CubeType &mag, arma::umat &Location, CubeType& MaxVal) const } } +/** + * Computes Gradient, Magnitude & Orientation. + */ template void StructuredForests:: -Gradient(const CubeType &InImage, - MatType &Magnitude, - MatType &Orientation) +Gradient(const CubeType& InImage, + MatType& Magnitude, + MatType& Orientation) { - const size_t grdNormRad = this->params.grdNormRad; + const size_t grdNormRad = this->params.GrdNormRad(); CubeType dx(InImage.n_rows, InImage.n_cols, InImage.n_slices), dy(InImage.n_rows, InImage.n_cols, InImage.n_slices); @@ -516,8 +494,7 @@ Gradient(const CubeType &InImage, this->MaxAndLoc(mag, Location, Magnitude); if(grdNormRad != 0) { - //we have to do this ugly thing, or override ConvTriangle - // and sepFilter2D methods. + //we have to do this or override ConvTriangle and SepFilter2D methods. 
CubeType mag2(InImage.n_rows, InImage.n_cols, 1); mag2.slice(0) = Magnitude; this->ConvTriangle(mag2, grdNormRad); @@ -535,7 +512,10 @@ Gradient(const CubeType &InImage, } } Orientation = arma::atan(dy_mat / dx_mat); - Orientation.transform( [](double val) { if(val < 0) return (val + arma::datum::pi); else return (val);} ); + + Orientation.transform( [](double val)\ + { if(val < 0) return (val + arma::datum::pi);\ + else return (val);} ); for(size_t j = 0; j < InImage.n_cols; ++j) { @@ -554,13 +534,12 @@ Histogram(const MatType& Magnitude, size_t downscale, size_t interp, CubeType& HistArr) { - //i don't think this function can be vectorized. //numOrient: number of orientations per gradient scale - const size_t numOrient = this->params.numOrient; - //size of HistArr: n_rbin * n_cbin * numOrient . . . (create in caller...) + const size_t numOrient = this->params.NumOrient(); const size_t n_rbin = (Magnitude.n_rows + downscale - 1) / downscale; const size_t n_cbin = (Magnitude.n_cols + downscale - 1) / downscale; + double o_range, o; o_range = arma::datum::pi / numOrient; @@ -597,32 +576,37 @@ Histogram(const MatType& Magnitude, HistArr.slice(i) = arma::square(HistArr.slice(i)); } +/** + * Shrink the size of Image by shrink size. + * Change color space of Image. + * Extract candidate features. + * @param InImage, Input Image. 
+ * @param regCh, + * @param ssCh, + */ + template void StructuredForests:: -GetShrunkChannels(const CubeType& InImage, CubeType ®_ch, CubeType &ss_ch) +GetShrunkChannels(const CubeType& InImage, CubeType& reg_ch,\ + CubeType& ss_ch, const arma::vec& table) { CubeType luv; - this->RGB2LUV(InImage, luv); - - const size_t shrink = this->params.shrink; - const size_t numOrient = this->params.numOrient; - const size_t grdSmoothRad = this->params.grdSmoothRad; - const size_t grdNormRad = this->params.grdNormRad; - const size_t num_channels = 13; + this->RGB2LUV(InImage, luv, table); + const size_t shrink = this->params.Shrink(); + const size_t grdSmoothRad = this->params.GrdSmoothRad(); + const size_t numChannels = 13; const size_t rsize = luv.n_rows / shrink; const size_t csize = luv.n_cols / shrink; - CubeType channels(rsize, csize, num_channels); - - - size_t slice_idx = 0; - + + CubeType channels(rsize, csize, numChannels); + + size_t slice_idx; for( slice_idx = 0; slice_idx < luv.n_slices; ++slice_idx) - this->BilinearInterpolation(luv.slice(slice_idx), (size_t)rsize, (size_t)csize + this->BilinearInterpolation(luv.slice(slice_idx), rsize, csize, channels.slice(slice_idx)); - - double scale = 0.5; + double scale = 1.0; - while(scale <= 1.0) + while(scale >= 0.5) { CubeType Img( (luv.n_rows * scale), (luv.n_cols * scale), @@ -631,76 +615,72 @@ GetShrunkChannels(const CubeType& InImage, CubeType ®_ch, CubeType &ss_ch) for( slice_idx = 0; slice_idx < luv.n_slices; ++slice_idx) { this->BilinearInterpolation(luv.slice(slice_idx), - (luv.n_rows * scale), (luv.n_cols * scale) + (luv.n_rows * scale), (luv.n_cols * scale), Img.slice(slice_idx)); } - CubeType OutImage = this->ConvTriangle(Img, grdSmoothRad); - + this->ConvTriangle(Img, grdSmoothRad); MatType Magnitude(InImage.n_rows, InImage.n_cols), Orientation(InImage.n_rows, InImage.n_cols); - this->Gradient(OutImage, Magnitude, Orientation); - + this->Gradient(Img, Magnitude, Orientation); size_t downscale = std::max(1, 
(int)(shrink * scale)); - CubeType Hist = this->Histogram(Magnitude, Orientation, - downscale, 0); - + CubeType Hist; + this->Histogram(Magnitude, Orientation, + downscale, 0, Hist); BilinearInterpolation( Magnitude, rsize, csize, channels.slice(slice_idx)); slice_idx++; for(size_t i = 0; i < InImage.n_slices; ++i) BilinearInterpolation( Magnitude, rsize, csize,\ channels.slice(i + slice_idx)); slice_idx += 3; - scale += 0.5; + scale -= 0.5; } //cout << "size of channels: " << arma::size(channels) << endl; double regSmoothRad, ssSmoothRad; - regSmoothRad = this->params.regSmoothRad / (double) shrink; - ssSmoothRad = this->params.ssSmoothRad / (double) shrink; - + regSmoothRad = this->params.RegSmoothRad() / (double) shrink; + ssSmoothRad = this->params.SSSmoothRad() / (double) shrink; + reg_ch = channels; + ss_ch = channels; if (regSmoothRad > 1.0) - reg_ch = this->ConvTriangle(channels, (size_t) (std::round(regSmoothRad)) ); + this->ConvTriangle(channels, (size_t) (std::round(regSmoothRad)) ); else - reg_ch = this->ConvTriangle(channels, regSmoothRad); + this->ConvTriangle(channels, regSmoothRad); if (ssSmoothRad > 1.0) - ss_ch = this->ConvTriangle(channels, (size_t) (std::round(ssSmoothRad)) ); + this->ConvTriangle(channels, (size_t) (std::round(ssSmoothRad)) ); else - ss_ch = this->ConvTriangle(channels, ssSmoothRad); + this->ConvTriangle(channels, ssSmoothRad); } template void StructuredForests:: -ViewAsWindows(const CubeType& channels, arma::umat const &loc, +ViewAsWindows(const CubeType& channels, const arma::umat& loc, CubeType& features) { // 500 for posLoc, and 500 for negLoc. // channels = 160, 240, 13. - features = CubeType(16, 16, 1000 * 13); + features = CubeType(16, 16, loc.n_rows * 13); const size_t patchSize = 16; const size_t p = patchSize / 2; //increase the channel boundary to protect error against Image boundaries. 
- CubeType inc_ch; - this->CopyMakeBorder(channels, p, p, p, p, inc_ch); + CubeType incCh; + this->CopyMakeBorder(channels, p, p, p, p, incCh); for (size_t i = 0, channel = 0; i < loc.n_rows; ++i) { size_t x = loc(i, 0); size_t y = loc(i, 1); - /*(x,y) in channels, is ((x+p), (y+p)) in inc_ch*/ - CubeType patch = inc_ch.tube((x + p) - p, (y + p) - p,\ + /*(x,y) in channels, is ((x+p), (y+p)) in incCh*/ + CubeType patch = incCh.tube((x + p) - p, (y + p) - p,\ (x + p) + p - 1, (y + p) + p - 1); - // since each patch has 13 channel we have to increase the index by 13 - - //cout <<"patch size = " << arma::size(patch) << endl; - + // since each patch has 13 channel we have to increase the index by 13 features.slices(channel, channel + 12) = patch; channel += 13; } @@ -708,7 +688,7 @@ ViewAsWindows(const CubeType& channels, arma::umat const &loc, template void StructuredForests:: -Rearrange(CubeType const &channels, CubeType& ch) +Rearrange(const CubeType& channels, CubeType& ch) { //we do (16,16,13*1000) to 256, 1000, 13, in vectorized code. ch = CubeType(256, 1000, 13); @@ -727,18 +707,17 @@ Rearrange(CubeType const &channels, CubeType& ch) // returns 256 * 1000 * 13 dImension features. template void StructuredForests:: -GetRegFtr(const CubeType& channels,const arma::umat& loc +GetRegFtr(const CubeType& channels, const arma::umat& loc, CubeType& RegFtr) { - int shrink = this->params.shrink; - int pSize = this->params.pSize / shrink; +// int pSize = this->params.PSize() / shrink; CubeType wind; this->ViewAsWindows(channels, loc, wind); this->Rearrange(wind, RegFtr); } template -CubeType StructuredForests:: +void StructuredForests:: PDist(const CubeType& features, const arma::uvec& grid_pos, CubeType& Output) { @@ -767,14 +746,14 @@ PDist(const CubeType& features, const arma::uvec& grid_pos, //returns 300,1000,13 dImension features. 
template void StructuredForests:: -GetSSFtr(CubeType const &channels, arma::umat const &loc +GetSSFtr(const CubeType& channels, const arma::umat& loc, CubeType SSFtr) { - const size_t shrink = this->params.shrink; - const size_t pSize = this->params.pSize / shrink; + const size_t shrink = this->params.Shrink(); + const size_t pSize = this->params.PSize() / shrink; //numCell: number of self sImilarity cells - const size_t numCell = this->params.numCell; + const size_t numCell = this->params.NumCell(); const size_t half_cell_size = (size_t) round(pSize / (2.0 * numCell)); arma::uvec g_pos(numCell); @@ -802,12 +781,12 @@ GetSSFtr(CubeType const &channels, arma::umat const &loc } template -void StructuredForests:: +void StructuredForests:: GetFeatures(const MatType &Image, arma::umat &loc, - CubeType& RegFtr, CubeType& SSFtr) + CubeType& RegFtr, CubeType& SSFtr, const arma::vec& table) { - const size_t rowSize = this->params.rowSize; - const size_t colSize = this->params.colSize; + const size_t rowSize = this->params.RowSize(); + const size_t colSize = this->params.ColSize(); const size_t bottom = (4 - (Image.n_rows / 3) % 4) % 4; const size_t right = (4 - Image.n_cols % 4) % 4; //cout << "Botttom = " << bottom << " right = " << right << endl; @@ -823,19 +802,21 @@ GetFeatures(const MatType &Image, arma::umat &loc, CubeType OutImage; this->CopyMakeBorder(InImage, 0, 0, bottom, right, OutImage); - const size_t num_channels = 13; - const size_t shrink = this->params.shrink; + + const size_t numChannels = 13; + const size_t shrink = this->params.Shrink(); const size_t rsize = OutImage.n_rows / shrink; const size_t csize = OutImage.n_cols / shrink; /* this part gives double free or corruption Out error when executed for a second tIme */ - CubeType reg_ch = CubeType(rsize, csize, num_channels); - CubeType ss_ch = CubeType(rsize, csize, num_channels); - this->GetShrunkChannels(InImage, reg_ch, ss_ch); - - loc /= shrink; + CubeType reg_ch = CubeType(rsize, csize, 
numChannels); + CubeType ss_ch = CubeType(rsize, csize, numChannels); + + this->GetShrunkChannels(InImage, reg_ch, ss_ch, table); + + loc /= shrink; this->GetRegFtr(reg_ch, loc, RegFtr); this->GetSSFtr(ss_ch, loc, SSFtr); } @@ -851,46 +832,76 @@ void StructuredForests:: PrepareData(const MatType& Images, const MatType& Boundaries,\ const MatType& Segmentations) { - const size_t numImages = this->params.numImages; - const size_t numTree = this->params.numTree; - const size_t numPos = this->params.numPos; - const size_t numNeg = this->params.numNeg; - const double fraction = this->params.fraction; - const size_t pSize = this->params.pSize; - const size_t gSize = this->params.gSize; - const size_t shrink = this->params.shrink; - const size_t rowSize = this->params.rowSize; - const size_t colSize = this->params.colSize; + const size_t numImages = this->params.NumImages(); + const size_t numTree = this->params.NumTree(); + const size_t numPos = this->params.NumPos(); + const size_t numNeg = this->params.NumNeg(); + const double fraction = this->params.Fraction(); + const size_t pSize = this->params.PSize(); + const size_t gSize = this->params.GSize(); + const size_t rowSize = this->params.RowSize(); + const size_t colSize = this->params.ColSize(); // pRad = radius of Image patches. // gRad = radius of ground truth patches. const size_t pRad = pSize / 2, gRad = gSize / 2; - - arma::vec FtrDIm; - this->GetFeatureDImension(FtrDIm); - const size_t nFtrDIm = FtrDIm(0) + FtrDIm(1); - const size_t nSmpFtrDIm = (size_t)(nFtrDIm * fraction); + arma::vec FtrDim; + this->GetFeatureDimension(FtrDim); + const size_t nFtrDim = FtrDim(0) + FtrDim(1); + const size_t nSmpFtrDim = (size_t)(nFtrDim * fraction); + for(size_t i = 0; i < numTree; ++i) { //Implement the logic for if data already exists. - MatType ftrs = arma::zeros(numPos + numNeg, nSmpFtrDIm); + MatType ftrs = arma::zeros(numPos + numNeg, nSmpFtrDim); //effectively a 3d array. . . 
- MatType lbls = arma::zeros( gSize * gSize, (numPos + numNeg )); + MatType lbls = arma::zeros((numPos + numNeg ), gSize * gSize); // still to be done: store features and labels calculated // in the loop and store it in these Matrices. // Could use some suggestions for this. - size_t loop_iter = num_Images; + size_t loop_iter = numImages; + + // a vector which helps in converting Image from RGB2LUV. + double a, y0, maxi; + a = std::pow(29.0, 3) / 27.0; + y0 = 8.0 / a; + maxi = 1.0 / 270.0; + arma::vec table(1064); + + for (size_t i = 0; i <= 1024; ++i) + { + table(i) = i / 1024.0; + + if (table(i) > y0) + table(i) = 116 * pow(table(i), 1.0/3.0) - 16.0; + else + table(i) = table(i) * a; + + table(i) = table(i) * maxi; + } + + for(size_t i = 1025; i < table.n_elem; ++i) + table(i) = table(i - 1); + + size_t col_i = 0, col_s = 0, col_b = 0; for(size_t j = 0; j < loop_iter; ++j) { MatType Img, bnds, segs; - Img = Images.submat(j * rowSize, 0, (j + 3) * rowSize - 1, colSize - 1); - bnds = Boundaries.submat( j * rowSize, 0, \ + Img = MatType(Images.colptr(col_i), colSize, rowSize * 3).t() / 255; + col_i += 3; + //Img = Images.submat((j * 3) * rowSize, 0, ((j * 3) + 3) * rowSize - 1, colSize - 1); + //bnds = Boundaries.submat( j * rowSize, 0, \ j * rowSize - 1, colSize - 1 ); - segs = Segmentations.submat( j * rowSize, 0, \ + + bnds = MatType(Boundaries.colptr(col_b), colSize, rowSize).t(); + col_b++; + //segs = Segmentations.submat( j * rowSize, 0, \ j * rowSize - 1, colSize - 1 ); + segs = MatType(Segmentations.colptr(col_s), colSize, rowSize).t(); + col_s++; MatType mask(rowSize, colSize, arma::fill::ones); mask.col(pRad - 1).fill(0); mask.row( (mask.n_rows - 1) - (pRad - 1) ).fill(0); @@ -902,30 +913,31 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ const size_t nPatchesPerGt = 500; MatType dis; - this->DistanceTransformImage(bnds, 1, dis) + this->DistanceTransformImage(bnds, 1, dis); dis = arma::sqrt(dis); - arma::uvec posLoc = arma::find( (dis < 
gRad) % mask ); arma::uvec negLoc = arma::find( (dis >= gRad) % mask ); posLoc = arma::shuffle(posLoc); negLoc = arma::shuffle(negLoc); - arma::umat loc(nPatchesPerGt * 2, 2); + size_t lenLoc = std::min((int) negLoc.n_elem, std::min((int) nPatchesPerGt,\ + (int) posLoc.n_elem)); + arma::umat loc(lenLoc * 2, 2); - for(size_t i = 0; i < nPatchesPerGt; ++i) + for(size_t i = 0; i < lenLoc; ++i) { loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), posLoc(i) ).t(); //cout << "posLoc: " << loc(i, 0) << ", " << loc(i, 1) << endl; } - for(size_t i = nPatchesPerGt; i < 2 * nPatchesPerGt; ++i) + for(size_t i = lenLoc; i < 2 * lenLoc; ++i) { - loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), negLoc(i - nPatchesPerGt) ).t(); + loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), negLoc(i - lenLoc) ).t(); } CubeType SSFtr, RegFtr; - this->GetFeatures(Img, loc, RegFtr, SSFtr); + this->GetFeatures(Img, loc, RegFtr, SSFtr, table); //randomly sample 70 values each from reg_ftr and ss_ftr. /* CubeType ftr(140, 1000, 13); @@ -934,64 +946,81 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ arma::uvec rs = r.shuffle(); arma::uvec ss = s.shuffle(); */ - MatType lbl(gSize * gSize, 1000); + //MatType lbl(1000, gSize * gSize); CubeType s(segs.n_rows, segs.n_cols, 1); // have to do this or we can overload the CopyMakeBorder to support MatType. 
s.slice(0) = segs; CubeType in_segs; - this->CopyMakeBorder(s, gRad, - gRad, gRad, gRad, in_segs); + this->CopyMakeBorder(s, gRad, gRad, gRad, + gRad, in_segs); for(size_t i = 0; i < loc.n_rows; ++i) { size_t x = loc(i, 0); size_t y = loc(i, 1); - //cout << "x, y = " << x << " " << y << endl; - lbl.col(i) = arma::vectorise(in_segs.slice(0)\ + //std::cout << "x, y = " << x << " " << y << std::endl; + lbls.row(i) = arma::vectorise(in_segs.slice(0)\ .submat((x + gRad) - gRad, (y + gRad) - gRad,\ - (x + gRad) + gRad - 1, (y + gRad) + gRad - 1)); + (x + gRad) + gRad - 1, (y + gRad) + gRad - 1)).t(); } } + arma::vec DiscreteLabels; + size_t x = Discretize(lbls, 2, 256, DiscreteLabels); } } -/* + +// returns the index of the most representative label, and discretizes structured +// label to discreet classes in matrix subLbls. (this is a vector if nClass = 2) template -void StructuredForests:: -Discretize(MatType const &labels, size_t n_class, size_t n_sample) +size_t StructuredForests:: +Discretize(const MatType& labels, const size_t nClass,\ + const size_t nSample, arma::vec& DiscreteLabels) { // Map labels to discrete class labels. - // lbls : 256 * 20000. - // n_sample: number of samples for clustering structured labels 256 + // lbls : 20000 * 256. + // nSample: number of samples for clustering structured labels 256 + // nClass: number of classes (clusters) for binary splits. 2 - // see the return type. - arma::uvec lis1(n_sample); - - MatType zs(n_sample, lbls.n_cols); + arma::uvec lis1(nSample); for (size_t i = 0; i < lis1.n_elem; ++i) lis1(i) = i; - MatType DiscreteLabels = arma::zeros(n_sample, n); - for (size_t i = 0; i < labels.n_cols; ++i) + MatType zs(labels.n_rows, nSample); + // no. of principal components to keep. 
+ size_t dim = std::min( 5, std::min( (int)nSample,\ + (int)std::floor( std::log2( (int)nClass ) ) ) ); + DiscreteLabels = arma::zeros(labels.n_rows, dim); + for (size_t j = 0; j < zs.n_cols; ++j) { - arma::uvec z1 = lis1.shuffle(); - arma::uvec z2 = lis2.shuffle(); - for (size_t j = 0; j < zs.n_rows; ++i) + arma::uvec z1 = arma::shuffle(lis1); + arma::uvec z2 = arma::shuffle(lis1); + for (size_t i = 0; i < zs.n_rows; ++i) zs(i, j) = (labels(i, z1(j)) == labels(i, z2(j))) ? 1 : 0; } - zs -= arma::mean(zs, 1); // calculate mean abOut cols. n_col = 256. - if ( arma::find(zs).n_elem == 0 ) + for (size_t i = 0; i < zs.n_cols; ++i) + zs.row(i) -= arma::mean(zs, 0); // calculate mean about rows. n_rows = 20000. + size_t ind = 0; + arma::uvec k = arma::find(zs > 0); + if ( k.n_elem == 0) { - labels.fill(ones); + DiscreteLabels.ones(); } else { - //find most representative segs + //find most representative label (closest to mean) + ind = arma::sum(arma::abs(zs), 0).index_min(); + // so most representative label is: labels.row(ind). + + // apply pca + MatType coeff, transformedData; + arma::vec eigVal; + mlpack::pca::PCA p; + p.Apply(zs.t(), transformedData, eigVal, coeff); + // we take only first row in transformedData (256 * 20000) as dim = 1. 
+ DiscreteLabels = arma::conv_to::from(transformedData.row(0).t() > 0); } - // discretize zs by discretizing pca dImensions - size_t d = min(5, n_sample, (size_t)floor(log(n_class, 2))); - zs = pca(); - -}*/ + return ind; +} } // namespace structured_tree } // namespace mlpack #endif diff --git a/src/mlpack/methods/edge_boxes/feature_parameters.hpp b/src/mlpack/methods/edge_boxes/feature_parameters.hpp index 3cf1f7a26a7..895ab878adf 100644 --- a/src/mlpack/methods/edge_boxes/feature_parameters.hpp +++ b/src/mlpack/methods/edge_boxes/feature_parameters.hpp @@ -16,7 +16,7 @@ class FeatureParameters { public: - FeatureParameters(); //default constructor + FeatureParameters(){} //default constructor void NumImages(size_t value) { numImages = value; } size_t NumImages() const { return numImages; } @@ -83,7 +83,6 @@ class FeatureParameters double fraction; size_t pSize; size_t gSize; - size_t numCell; size_t numPos; size_t numNeg; size_t numCell; @@ -92,3 +91,5 @@ class FeatureParameters } } +#include "feature_extraction.hpp" +#endif diff --git a/src/mlpack/tests/edge_boxes_test.cpp b/src/mlpack/tests/edge_boxes_test.cpp index 0d699a0034b..94b31c0509b 100644 --- a/src/mlpack/tests/edge_boxes_test.cpp +++ b/src/mlpack/tests/edge_boxes_test.cpp @@ -6,7 +6,6 @@ */ #include -#include #include #include @@ -66,7 +65,8 @@ void DistanceTransformTest(arma::mat& input, double on, arma::mat& output, StructuredForests& SF) { - arma::mat dt_output = SF.DistanceTransformImage(input, on); + arma::mat dt_output; + SF.DistanceTransformImage(input, on, dt_output); Test(dt_output, output); } @@ -74,53 +74,71 @@ void CopyMakeBorderTest(arma::cube& input, arma::cube& output, StructuredForests& SF) { - arma::cube border_output = SF.CopyMakeBorder(input, 1, 1, 1, 1); + arma::cube border_output; + SF.CopyMakeBorder(input, 1, 1, 1, 1, border_output); Test(border_output, output); } void RGB2LUVTest(arma::cube& input, arma::cube& output, StructuredForests& SF) { - arma::cube luv = 
SF.RGB2LUV(input); + double a, y0, maxi; + a = std::pow(29.0, 3) / 27.0; + y0 = 8.0 / a; + maxi = 1.0 / 270.0; + arma::vec table(1064); + + for (size_t i = 0; i <= 1024; ++i) + { + table(i) = i / 1024.0; + + if (table(i) > y0) + table(i) = 116 * pow(table(i), 1.0/3.0) - 16.0; + else + table(i) = table(i) * a; + + table(i) = table(i) * maxi; + } + + for(size_t i = 1025; i < table.n_elem; ++i) + table(i) = table(i - 1); + + arma::cube luv; + SF.RGB2LUV(input, luv, table); Test(luv, output); } void ConvTriangleTest(arma::cube& input, int radius, arma::cube& output, StructuredForests& SF) { - arma::cube conv_out = SF.ConvTriangle(input, radius); - Test(conv_out, output); + SF.ConvTriangle(input, radius); + Test(input, output); } BOOST_AUTO_TEST_CASE(FeatureExtractionTest) { - std::map options; - options["num_images"] = 2; - options["row_size"] = 321; - options["col_size"] = 481; - options["rgbd"] = 0; - options["shrink"] = 2; - options["n_orient"] = 4; - options["grd_smooth_rad"] = 0; - options["grd_norm_rad"] = 4; - options["reg_smooth_rad"] = 2; - options["ss_smooth_rad"] = 8; - options["p_size"] = 32; - options["g_size"] = 16; - options["n_cell"] = 5; - - options["n_pos"] = 10000; - options["n_neg"] = 10000; - options["n_tree"] = 8; - options["n_class"] = 2; - options["min_count"] = 1; - options["min_child"] = 8; - options["max_depth"] = 64; - options["split"] = 0; // we use 0 for gini, 1 for entropy, 2 for other - options["stride"] = 2; - options["sharpen"] = 2; - options["n_tree_eval"] = 4; - options["nms"] = 1; // 1 for true, 0 for false + FeatureParameters params = FeatureParameters(); + + params.NumImages(2); + params.RowSize(321); + params.ColSize(481); + params.RGBD(0); + params.Shrink(2); + params.NumOrient(4); + params.GrdSmoothRad(0); + params.GrdNormRad(4); + params.RegSmoothRad(2); + params.SSSmoothRad(8); + params.Fraction(0.25); + params.PSize(32); + params.GSize(16); + params.NumCell(5); + params.NumPos(10000); + params.NumNeg(10000); + 
params.NumCell(5); + params.NumTree(8); + + StructuredForests SF(params); arma::mat input, output; input << 0 << 0 << 0 << arma::endr @@ -130,7 +148,7 @@ BOOST_AUTO_TEST_CASE(FeatureExtractionTest) output << 2 << 1 << 2 << arma::endr << 1 << 0 << 1 << arma::endr << 0 << 1 << 2; - StructuredForests SF(options); + DistanceTransformTest(input, 1, output, SF); arma::cube in1(input.n_rows, input.n_cols, 1); @@ -185,3 +203,4 @@ out_luv.slice(2) << 0.496295 << 0.496295 << 0.496295 << arma::endr BOOST_AUTO_TEST_SUITE_END(); + diff --git a/src/mlpack/tests/pca_test.cpp b/src/mlpack/tests/pca_test.cpp index d7a78c8830a..2dc6b071b06 100644 --- a/src/mlpack/tests/pca_test.cpp +++ b/src/mlpack/tests/pca_test.cpp @@ -23,7 +23,7 @@ using namespace mlpack::distribution; */ BOOST_AUTO_TEST_CASE(ArmaComparisonPCATest) { - mat coeff, coeff1; + /*mat coeff, coeff1; vec eigVal, eigVal1; mat score, score1; @@ -41,7 +41,22 @@ BOOST_AUTO_TEST_CASE(ArmaComparisonPCATest) BOOST_REQUIRE_SMALL(eigVal1[i], 1e-15); else BOOST_REQUIRE_CLOSE(eigVal[i], eigVal1[i], 0.0001); - } + }*/ + + mat coeff, coeff1; + vec eigVal, eigVal1; + mat score, score1; + + mat data = randu(20000, 256); + + PCA p; + + p.Apply(data.t(), score1, eigVal1, coeff1); + cout << size(data.t()) << endl; + cout << "_____________" << endl; + cout << size(eigVal1) << endl; + cout << "______________" << endl; + cout << size(score1) << endl; } /** @@ -199,3 +214,4 @@ BOOST_AUTO_TEST_CASE(PCAScalingTest) BOOST_AUTO_TEST_SUITE_END(); + From f788f9761657892298fb87448fc6a42d11cdd56b Mon Sep 17 00:00:00 2001 From: Jain Date: Mon, 20 Jun 2016 11:08:40 +0530 Subject: [PATCH 05/14] added discretize function --- .../edge_boxes/feature_extraction_impl.hpp | 45 ------------------- 1 file changed, 45 deletions(-) diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp index c175906081e..7cd200630ae 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp 
+++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ -24,51 +24,6 @@ StructuredForests(FeatureParameters F) params = F; } -/* -template -MatType StructuredForests:: -LoadData(MatType const &Images, MatType const &boundaries,\ - MatType const &segmentations) -{ - const size_t num_Images = this->params.num_Images; - const size_t rowSize = this->params.RowSize(); - const size_t colSize = this->params.ColSize(); - MatType input_data(num_Images * rowSize * 5, colSize); - // we store the input data as follows: - // Images (3), boundaries (1), segmentations (1). - size_t loop_iter = num_Images * 5; - size_t row_idx = 0; - size_t col_i = 0, col_s = 0, col_b = 0; - for(size_t i = 0; i < loop_iter; ++i) - { - if (i % 5 == 4) - { - input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ - colSize - 1) = MatType(segmentations.colptr(col_s),\ - colSize, rowSize).t(); - ++col_s; - } - else if (i % 5 == 3) - { - input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ - colSize - 1) = MatType(boundaries.colptr(col_b),\ - colSize, rowSize).t(); - ++col_b; - } - else - { - input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ - colSize - 1) = MatType(Images.colptr(col_i), - colSize, rowSize).t(); - ++col_i; - } - row_idx += rowSize; - } - return input_data; -} - -*/ - /** * Get DImensions of Features * @param FtrDim Output vector that contains the result From 49c25fe8186a73c4a62e887ae7b8e2c463cd43c5 Mon Sep 17 00:00:00 2001 From: Jain Date: Mon, 20 Jun 2016 11:53:34 +0530 Subject: [PATCH 06/14] fixed unintended changes --- .../edge_boxes/feature_extraction_impl.hpp | 2 +- src/mlpack/tests/pca_test.cpp | 19 ++----------------- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp index 7cd200630ae..b6d2ba9a3d7 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ 
-25,7 +25,7 @@ StructuredForests(FeatureParameters F) } /** - * Get DImensions of Features + * Get Dimensions of Features * @param FtrDim Output vector that contains the result */ template diff --git a/src/mlpack/tests/pca_test.cpp b/src/mlpack/tests/pca_test.cpp index 2dc6b071b06..4ee9b006763 100644 --- a/src/mlpack/tests/pca_test.cpp +++ b/src/mlpack/tests/pca_test.cpp @@ -23,7 +23,7 @@ using namespace mlpack::distribution; */ BOOST_AUTO_TEST_CASE(ArmaComparisonPCATest) { - /*mat coeff, coeff1; + mat coeff, coeff1; vec eigVal, eigVal1; mat score, score1; @@ -41,22 +41,7 @@ BOOST_AUTO_TEST_CASE(ArmaComparisonPCATest) BOOST_REQUIRE_SMALL(eigVal1[i], 1e-15); else BOOST_REQUIRE_CLOSE(eigVal[i], eigVal1[i], 0.0001); - }*/ - - mat coeff, coeff1; - vec eigVal, eigVal1; - mat score, score1; - - mat data = randu(20000, 256); - - PCA p; - - p.Apply(data.t(), score1, eigVal1, coeff1); - cout << size(data.t()) << endl; - cout << "_____________" << endl; - cout << size(eigVal1) << endl; - cout << "______________" << endl; - cout << size(score1) << endl; + } } /** From 9399cd37d122d0eabc24bd38a4cc2966c59d9da2 Mon Sep 17 00:00:00 2001 From: Jain Date: Mon, 20 Jun 2016 12:44:44 +0530 Subject: [PATCH 07/14] backported index_min and index_max --- src/mlpack/core/arma_extend/CMakeLists.txt | 1 + src/mlpack/core/arma_extend/arma_extend.hpp | 2 + .../core/arma_extend/fn_index_min_max.hpp | 322 ++++++++++++++++++ 3 files changed, 325 insertions(+) create mode 100644 src/mlpack/core/arma_extend/fn_index_min_max.hpp diff --git a/src/mlpack/core/arma_extend/CMakeLists.txt b/src/mlpack/core/arma_extend/CMakeLists.txt index db0c2212c38..4cdcb38538a 100644 --- a/src/mlpack/core/arma_extend/CMakeLists.txt +++ b/src/mlpack/core/arma_extend/CMakeLists.txt @@ -4,6 +4,7 @@ set(SOURCES arma_extend.hpp fn_ccov.hpp fn_ind2sub.hpp + fn_index_min_max.hpp glue_ccov_meat.hpp glue_ccov_proto.hpp hdf5_misc.hpp diff --git a/src/mlpack/core/arma_extend/arma_extend.hpp 
b/src/mlpack/core/arma_extend/arma_extend.hpp index 12765c775f2..7ba3a1676c8 100644 --- a/src/mlpack/core/arma_extend/arma_extend.hpp +++ b/src/mlpack/core/arma_extend/arma_extend.hpp @@ -68,6 +68,8 @@ namespace arma { // index to subscript and vice versa #include "fn_ind2sub.hpp" + // to find index of min/max value + #include "fn_index_min_max.hpp" // inplace_reshape() #include "fn_inplace_reshape.hpp" diff --git a/src/mlpack/core/arma_extend/fn_index_min_max.hpp b/src/mlpack/core/arma_extend/fn_index_min_max.hpp new file mode 100644 index 00000000000..98db7c88ec8 --- /dev/null +++ b/src/mlpack/core/arma_extend/fn_index_min_max.hpp @@ -0,0 +1,322 @@ +#if (ARMA_VERSION_MAJOR < 7 && ARMA_VERSION_MINOR < 200) +template +inline +typename arma_not_cx::result +op_min::min_with_index(const Proxy& P, uword& index_of_min_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + eT best_val = priv::most_pos(); + uword best_index = 0; + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + for(uword i=0; i +inline +typename arma_not_cx::result +op_min::min_with_index(const ProxyCube& P, uword& index_of_min_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "min(): object has no elements"); + + return Datum::nan; + } + + eT best_val = priv::most_pos(); + uword best_index = 0; + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + const eT tmp = A[i]; + + if(tmp < best_val) { best_val = tmp; best_index = i; } + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = 
P.get_n_slices(); + + uword count = 0; + + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + const eT tmp = P.at(row,col,slice); + + if(tmp < best_val) { best_val = tmp; best_index = count; } + + ++count; + } + } + + index_of_min_val = best_index; + + return best_val; + } + + template +inline +typename arma_not_cx::result +op_max::max_with_index(const Proxy& P, uword& index_of_max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + eT best_val = priv::most_neg(); + uword best_index = 0; + + if(Proxy::use_at == false) + { + typedef typename Proxy::ea_type ea_type; + + ea_type A = P.get_ea(); + + for(uword i=0; i best_val) { best_val = tmp; best_index = i; } + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + if(n_rows == 1) + { + for(uword i=0; i < n_cols; ++i) + { + const eT tmp = P.at(0,i); + + if(tmp > best_val) { best_val = tmp; best_index = i; } + } + } + else + if(n_cols == 1) + { + for(uword i=0; i < n_rows; ++i) + { + const eT tmp = P.at(i,0); + + if(tmp > best_val) { best_val = tmp; best_index = i; } + } + } + else + { + uword count = 0; + + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT tmp = P.at(row,col); + + if(tmp > best_val) { best_val = tmp; best_index = count; } + + ++count; + } + } + } + + index_of_max_val = best_index; + + return best_val; + } + +template +inline +typename arma_not_cx::result +op_max::max_with_index(const ProxyCube& P, uword& index_of_max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) + { + arma_debug_check(true, "max(): object has no elements"); + + return Datum::nan; + } + + eT 
best_val = priv::most_neg(); + uword best_index = 0; + + if(ProxyCube::use_at == false) + { + typedef typename ProxyCube::ea_type ea_type; + + ea_type A = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) + { + const eT tmp = A[i]; + + if(tmp > best_val) { best_val = tmp; best_index = i; } + } + } + else + { + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + uword count = 0; + + for(uword slice=0; slice < n_slices; ++slice) + for(uword col=0; col < n_cols; ++col ) + for(uword row=0; row < n_rows; ++row ) + { + const eT tmp = P.at(row,col,slice); + + if(tmp > best_val) { best_val = tmp; best_index = count; } + + ++count; + } + } + + index_of_max_val = best_index; + + return best_val; + } + +template +inline +arma_warn_unused +uword +Base::index_min() const + { + const Proxy P( (*this).get_ref() ); + + uword index = 0; + + if(P.get_n_elem() == 0) + { + arma_debug_check(true, "index_min(): object has no elements"); + } + else + { + op_min::min_with_index(P, index); + } + + return index; + } + + +template +inline +arma_warn_unused +uword +Base::index_max() const + { + const Proxy P( (*this).get_ref() ); + + uword index = 0; + + if(P.get_n_elem() == 0) + { + arma_debug_check(true, "index_max(): object has no elements"); + } + else + { + op_max::max_with_index(P, index); + } + + return index; + } +#endif \ No newline at end of file From 0abec2cbb44a0556b129485b485bae8d07986af7 Mon Sep 17 00:00:00 2001 From: nilayjain Date: Mon, 20 Jun 2016 16:31:27 +0000 Subject: [PATCH 08/14] added some timers for analysis --- .../core/arma_extend/fn_index_min_max.hpp | 5 +- src/mlpack/methods/CMakeLists.txt | 2 +- .../methods/edge_boxes/edge_boxes_main.cpp | 4 +- .../edge_boxes/feature_extraction_impl.hpp | 63 ++++++++++++++++++- src/mlpack/tests/CMakeLists.txt | 4 +- 5 files changed, 69 insertions(+), 9 deletions(-) diff --git a/src/mlpack/core/arma_extend/fn_index_min_max.hpp 
b/src/mlpack/core/arma_extend/fn_index_min_max.hpp index 98db7c88ec8..c605bab8b93 100644 --- a/src/mlpack/core/arma_extend/fn_index_min_max.hpp +++ b/src/mlpack/core/arma_extend/fn_index_min_max.hpp @@ -1,4 +1,5 @@ -#if (ARMA_VERSION_MAJOR < 7 && ARMA_VERSION_MINOR < 200) +#if (ARMA_VERSION_MAJOR < 7\ + || (ARMA_VERSION_MAJOR == 7 && ARMA_VERSION_MINOR < 200)) template inline typename arma_not_cx::result @@ -319,4 +320,4 @@ Base::index_max() const return index; } -#endif \ No newline at end of file +#endif diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt index adb67489b67..d0ba0ee3a66 100644 --- a/src/mlpack/methods/CMakeLists.txt +++ b/src/mlpack/methods/CMakeLists.txt @@ -24,7 +24,7 @@ set(DIRS det emst edge_boxes - fastmks +# fastmks gmm hmm hoeffding_trees diff --git a/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp b/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp index 4afa4cbc958..a54f6a4dd8c 100644 --- a/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp +++ b/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp @@ -11,8 +11,9 @@ using namespace mlpack; using namespace mlpack::structured_tree; using namespace std; -int main() +int main(int argc, char** argv) { + CLI::ParseCommandLine(argc, argv); /* :param options: num_images: number of images in the dataset. 
@@ -79,3 +80,4 @@ int main() } + diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp index b6d2ba9a3d7..093fcf24fe9 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ -24,8 +24,53 @@ StructuredForests(FeatureParameters F) params = F; } +/* +template +MatType StructuredForests:: +LoadData(MatType const &Images, MatType const &boundaries,\ + MatType const &segmentations) +{ + const size_t num_Images = this->params.num_Images; + const size_t rowSize = this->params.RowSize(); + const size_t colSize = this->params.ColSize(); + MatType input_data(num_Images * rowSize * 5, colSize); + // we store the input data as follows: + // Images (3), boundaries (1), segmentations (1). + size_t loop_iter = num_Images * 5; + size_t row_idx = 0; + size_t col_i = 0, col_s = 0, col_b = 0; + for(size_t i = 0; i < loop_iter; ++i) + { + if (i % 5 == 4) + { + input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ + colSize - 1) = MatType(segmentations.colptr(col_s),\ + colSize, rowSize).t(); + ++col_s; + } + else if (i % 5 == 3) + { + input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ + colSize - 1) = MatType(boundaries.colptr(col_b),\ + colSize, rowSize).t(); + ++col_b; + } + else + { + input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ + colSize - 1) = MatType(Images.colptr(col_i), + colSize, rowSize).t(); + ++col_i; + } + row_idx += rowSize; + } + return input_data; +} + +*/ + /** - * Get Dimensions of Features + * Get DImensions of Features * @param FtrDim Output vector that contains the result */ template @@ -804,7 +849,7 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ const size_t nFtrDim = FtrDim(0) + FtrDim(1); const size_t nSmpFtrDim = (size_t)(nFtrDim * fraction); - + size_t time=0; for(size_t i = 0; i < numTree; ++i) { //Implement the logic for if data already exists. 
@@ -892,7 +937,11 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ } CubeType SSFtr, RegFtr; + Timer::Start("get_features"); + this->GetFeatures(Img, loc, RegFtr, SSFtr, table); + Timer::Stop("get_features"); + //randomly sample 70 values each from reg_ftr and ss_ftr. /* CubeType ftr(140, 1000, 13); @@ -907,7 +956,7 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ // have to do this or we can overload the CopyMakeBorder to support MatType. s.slice(0) = segs; CubeType in_segs; - this->CopyMakeBorder(s, gRad, gRad, gRad, + this->CopyMakeBorder(s, gRad, gRad, gRad,\ gRad, in_segs); for(size_t i = 0; i < loc.n_rows; ++i) @@ -935,6 +984,7 @@ Discretize(const MatType& labels, const size_t nClass,\ // lbls : 20000 * 256. // nSample: number of samples for clustering structured labels 256 // nClass: number of classes (clusters) for binary splits. 2 + Timer::Start("other_discretize"); arma::uvec lis1(nSample); for (size_t i = 0; i < lis1.n_elem; ++i) @@ -967,14 +1017,21 @@ Discretize(const MatType& labels, const size_t nClass,\ // so most representative label is: labels.row(ind). // apply pca + Timer::Stop("other_discretize"); + Timer::Start("pca_timer"); MatType coeff, transformedData; arma::vec eigVal; mlpack::pca::PCA p; p.Apply(zs.t(), transformedData, eigVal, coeff); // we take only first row in transformedData (256 * 20000) as dim = 1. 
+ Timer::Stop("pca_timer"); + Timer::Start("other_discretize"); + //std::cout << Timer::Get("pca_timer") << std::endl; DiscreteLabels = arma::conv_to::from(transformedData.row(0).t() > 0); + Timer::Stop("other_discretize"); } return ind; + } } // namespace structured_tree } // namespace mlpack diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt index abf91d59de9..95bc6949b7d 100644 --- a/src/mlpack/tests/CMakeLists.txt +++ b/src/mlpack/tests/CMakeLists.txt @@ -18,7 +18,7 @@ add_executable(mlpack_test distribution_test.cpp emst_test.cpp edge_boxes_test.cpp - fastmks_test.cpp +# fastmks_test.cpp feedforward_network_test.cpp gmm_test.cpp hmm_test.cpp @@ -64,7 +64,7 @@ add_executable(mlpack_test sgd_test.cpp serialization.hpp serialization.cpp - serialization_test.cpp +# serialization_test.cpp softmax_regression_test.cpp sort_policy_test.cpp sparse_autoencoder_test.cpp From 987f0790f745314a2117a87f28a346b05a8a2a2d Mon Sep 17 00:00:00 2001 From: nilayjain Date: Fri, 24 Jun 2016 16:59:39 +0000 Subject: [PATCH 09/14] added IndexMin function --- src/mlpack/core/arma_extend/CMakeLists.txt | 1 - src/mlpack/core/arma_extend/arma_extend.hpp | 3 +- .../core/arma_extend/fn_index_min_max.hpp | 323 ------------------ .../methods/edge_boxes/feature_extraction.hpp | 5 +- .../edge_boxes/feature_extraction_impl.hpp | 23 +- 5 files changed, 23 insertions(+), 332 deletions(-) delete mode 100644 src/mlpack/core/arma_extend/fn_index_min_max.hpp diff --git a/src/mlpack/core/arma_extend/CMakeLists.txt b/src/mlpack/core/arma_extend/CMakeLists.txt index 4cdcb38538a..db0c2212c38 100644 --- a/src/mlpack/core/arma_extend/CMakeLists.txt +++ b/src/mlpack/core/arma_extend/CMakeLists.txt @@ -4,7 +4,6 @@ set(SOURCES arma_extend.hpp fn_ccov.hpp fn_ind2sub.hpp - fn_index_min_max.hpp glue_ccov_meat.hpp glue_ccov_proto.hpp hdf5_misc.hpp diff --git a/src/mlpack/core/arma_extend/arma_extend.hpp b/src/mlpack/core/arma_extend/arma_extend.hpp index 7ba3a1676c8..1978c45251a 100644 
--- a/src/mlpack/core/arma_extend/arma_extend.hpp +++ b/src/mlpack/core/arma_extend/arma_extend.hpp @@ -68,8 +68,7 @@ namespace arma { // index to subscript and vice versa #include "fn_ind2sub.hpp" - // to find index of min/max value - #include "fn_index_min_max.hpp" + // inplace_reshape() #include "fn_inplace_reshape.hpp" diff --git a/src/mlpack/core/arma_extend/fn_index_min_max.hpp b/src/mlpack/core/arma_extend/fn_index_min_max.hpp deleted file mode 100644 index c605bab8b93..00000000000 --- a/src/mlpack/core/arma_extend/fn_index_min_max.hpp +++ /dev/null @@ -1,323 +0,0 @@ -#if (ARMA_VERSION_MAJOR < 7\ - || (ARMA_VERSION_MAJOR == 7 && ARMA_VERSION_MINOR < 200)) -template -inline -typename arma_not_cx::result -op_min::min_with_index(const Proxy& P, uword& index_of_min_val) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT; - - const uword n_elem = P.get_n_elem(); - - if(n_elem == 0) - { - arma_debug_check(true, "min(): object has no elements"); - - return Datum::nan; - } - - eT best_val = priv::most_pos(); - uword best_index = 0; - - if(Proxy::use_at == false) - { - typedef typename Proxy::ea_type ea_type; - - ea_type A = P.get_ea(); - - for(uword i=0; i -inline -typename arma_not_cx::result -op_min::min_with_index(const ProxyCube& P, uword& index_of_min_val) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT; - - const uword n_elem = P.get_n_elem(); - - if(n_elem == 0) - { - arma_debug_check(true, "min(): object has no elements"); - - return Datum::nan; - } - - eT best_val = priv::most_pos(); - uword best_index = 0; - - if(ProxyCube::use_at == false) - { - typedef typename ProxyCube::ea_type ea_type; - - ea_type A = P.get_ea(); - - for(uword i=0; i < n_elem; ++i) - { - const eT tmp = A[i]; - - if(tmp < best_val) { best_val = tmp; best_index = i; } - } - } - else - { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); - const uword n_slices = P.get_n_slices(); - - uword count = 0; - - for(uword 
slice=0; slice < n_slices; ++slice) - for(uword col=0; col < n_cols; ++col ) - for(uword row=0; row < n_rows; ++row ) - { - const eT tmp = P.at(row,col,slice); - - if(tmp < best_val) { best_val = tmp; best_index = count; } - - ++count; - } - } - - index_of_min_val = best_index; - - return best_val; - } - - template -inline -typename arma_not_cx::result -op_max::max_with_index(const Proxy& P, uword& index_of_max_val) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT; - - const uword n_elem = P.get_n_elem(); - - if(n_elem == 0) - { - arma_debug_check(true, "max(): object has no elements"); - - return Datum::nan; - } - - eT best_val = priv::most_neg(); - uword best_index = 0; - - if(Proxy::use_at == false) - { - typedef typename Proxy::ea_type ea_type; - - ea_type A = P.get_ea(); - - for(uword i=0; i best_val) { best_val = tmp; best_index = i; } - } - } - else - { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); - - if(n_rows == 1) - { - for(uword i=0; i < n_cols; ++i) - { - const eT tmp = P.at(0,i); - - if(tmp > best_val) { best_val = tmp; best_index = i; } - } - } - else - if(n_cols == 1) - { - for(uword i=0; i < n_rows; ++i) - { - const eT tmp = P.at(i,0); - - if(tmp > best_val) { best_val = tmp; best_index = i; } - } - } - else - { - uword count = 0; - - for(uword col=0; col < n_cols; ++col) - for(uword row=0; row < n_rows; ++row) - { - const eT tmp = P.at(row,col); - - if(tmp > best_val) { best_val = tmp; best_index = count; } - - ++count; - } - } - } - - index_of_max_val = best_index; - - return best_val; - } - -template -inline -typename arma_not_cx::result -op_max::max_with_index(const ProxyCube& P, uword& index_of_max_val) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT; - - const uword n_elem = P.get_n_elem(); - - if(n_elem == 0) - { - arma_debug_check(true, "max(): object has no elements"); - - return Datum::nan; - } - - eT best_val = priv::most_neg(); - uword best_index = 0; - - 
if(ProxyCube::use_at == false) - { - typedef typename ProxyCube::ea_type ea_type; - - ea_type A = P.get_ea(); - - for(uword i=0; i < n_elem; ++i) - { - const eT tmp = A[i]; - - if(tmp > best_val) { best_val = tmp; best_index = i; } - } - } - else - { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); - const uword n_slices = P.get_n_slices(); - - uword count = 0; - - for(uword slice=0; slice < n_slices; ++slice) - for(uword col=0; col < n_cols; ++col ) - for(uword row=0; row < n_rows; ++row ) - { - const eT tmp = P.at(row,col,slice); - - if(tmp > best_val) { best_val = tmp; best_index = count; } - - ++count; - } - } - - index_of_max_val = best_index; - - return best_val; - } - -template -inline -arma_warn_unused -uword -Base::index_min() const - { - const Proxy P( (*this).get_ref() ); - - uword index = 0; - - if(P.get_n_elem() == 0) - { - arma_debug_check(true, "index_min(): object has no elements"); - } - else - { - op_min::min_with_index(P, index); - } - - return index; - } - - -template -inline -arma_warn_unused -uword -Base::index_max() const - { - const Proxy P( (*this).get_ref() ); - - uword index = 0; - - if(P.get_n_elem() == 0) - { - arma_debug_check(true, "index_max(): object has no elements"); - } - else - { - op_max::max_with_index(P, index); - } - - return index; - } -#endif diff --git a/src/mlpack/methods/edge_boxes/feature_extraction.hpp b/src/mlpack/methods/edge_boxes/feature_extraction.hpp index 16fa8c1ad92..b94dad5f3bc 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction.hpp @@ -18,10 +18,10 @@ class StructuredForests { private: FeatureParameters params; + static constexpr double eps = 1e-20; public: - static constexpr double eps = 1e-20; StructuredForests(FeatureParameters F); /* MatType LoadData(MatType const &images, MatType const &boundaries,\ @@ -86,6 +86,8 @@ class StructuredForests void PDist(const CubeType& features, const arma::uvec& grid_pos, 
CubeType& Output); + size_t IndexMin(arma::vec& k); + size_t Discretize(const MatType& labels, const size_t nClass,\ const size_t nSample, arma::vec& DiscreteLabels); }; @@ -97,3 +99,4 @@ class StructuredForests #endif + diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp index 093fcf24fe9..3fe1f860313 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ -19,10 +19,7 @@ namespace structured_tree { */ template StructuredForests:: -StructuredForests(FeatureParameters F) -{ - params = F; -} +StructuredForests(FeatureParameters F) : params(F) {} /* template @@ -973,6 +970,21 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ } } +template +size_t StructuredForests:: +IndexMin(arma::vec& k) +{ + double s = k(0); size_t ind = 0; + for (size_t i = 1; i < k.n_elem; ++i) + { + if (k(i) < s) + { + s = k(i); + ind = i; + } + } + return ind; +} // returns the index of the most representative label, and discretizes structured // label to discreet classes in matrix subLbls. (this is a vector if nClass = 2) template @@ -1013,7 +1025,8 @@ Discretize(const MatType& labels, const size_t nClass,\ else { //find most representative label (closest to mean) - ind = arma::sum(arma::abs(zs), 0).index_min(); + arma::vec k = arma::sum(arma::abs(zs), 0); + ind = IndexMin(k); // so most representative label is: labels.row(ind). 
// apply pca From f041da4deac2d9af61df51e91a356f8acba78326 Mon Sep 17 00:00:00 2001 From: nilayjain Date: Mon, 27 Jun 2016 04:44:41 +0000 Subject: [PATCH 10/14] timing tests, computing gradient by sobel filter --- .../methods/edge_boxes/edge_boxes_main.cpp | 1 - .../methods/edge_boxes/feature_extraction.hpp | 24 +-- .../edge_boxes/feature_extraction_impl.hpp | 169 ++++++------------ 3 files changed, 62 insertions(+), 132 deletions(-) diff --git a/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp b/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp index a54f6a4dd8c..9d9fcb25265 100644 --- a/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp +++ b/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp @@ -46,7 +46,6 @@ int main(int argc, char** argv) */ FeatureParameters params = FeatureParameters(); - params.NumImages(2); params.RowSize(321); params.ColSize(481); diff --git a/src/mlpack/methods/edge_boxes/feature_extraction.hpp b/src/mlpack/methods/edge_boxes/feature_extraction.hpp index b94dad5f3bc..701d12a4fa9 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction.hpp @@ -6,8 +6,6 @@ */ #ifndef MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_HPP #define MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_HPP -//#define INF 999999.9999 -//#define EPS 1E-20 #include #include "feature_parameters.hpp" namespace mlpack { @@ -18,14 +16,11 @@ class StructuredForests { private: FeatureParameters params; - static constexpr double eps = 1e-20; public: - StructuredForests(FeatureParameters F); -/* MatType LoadData(MatType const &images, MatType const &boundaries,\ - MatType const &segmentations);*/ + StructuredForests(FeatureParameters F); void PrepareData(const MatType& Images, const MatType& Boundaries,\ const MatType& Segmentations); @@ -54,13 +49,17 @@ class StructuredForests const arma::vec& table); void BilinearInterpolation(const MatType& src, - size_t height, size_t width, - MatType& dst); + const size_t height, + const 
size_t width, + MatType& dst); - void SepFilter2D(CubeType &InOutImage, const arma::vec& kernel, const size_t radius); + void Convolution(CubeType &InOutImage, const MatType& Filter, const size_t radius); void ConvTriangle(CubeType &InImage, const size_t radius); + void ConvTriangle2(CubeType& InImage, const size_t radius, CubeType& Output); + + void Gradient(const CubeType& InImage, MatType& Magnitude, MatType& Orientation); @@ -68,9 +67,10 @@ class StructuredForests void MaxAndLoc(CubeType &mag, arma::umat &Location, MatType& MaxVal) const; void Histogram(const MatType& Magnitude, - const MatType& Orientation, - size_t downscale, size_t interp, - CubeType& HistArr); + const MatType& Orientation, + const size_t downscale, + const size_t interp, + CubeType& HistArr); void ViewAsWindows(const CubeType& channels, const arma::umat& loc, CubeType& features); diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp index 3fe1f860313..f70adc72516 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ -10,6 +10,8 @@ #include "feature_extraction.hpp" #include + +using namespace mlpack::ann; namespace mlpack { namespace structured_tree { @@ -21,53 +23,8 @@ template StructuredForests:: StructuredForests(FeatureParameters F) : params(F) {} -/* -template -MatType StructuredForests:: -LoadData(MatType const &Images, MatType const &boundaries,\ - MatType const &segmentations) -{ - const size_t num_Images = this->params.num_Images; - const size_t rowSize = this->params.RowSize(); - const size_t colSize = this->params.ColSize(); - MatType input_data(num_Images * rowSize * 5, colSize); - // we store the input data as follows: - // Images (3), boundaries (1), segmentations (1). 
- size_t loop_iter = num_Images * 5; - size_t row_idx = 0; - size_t col_i = 0, col_s = 0, col_b = 0; - for(size_t i = 0; i < loop_iter; ++i) - { - if (i % 5 == 4) - { - input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ - colSize - 1) = MatType(segmentations.colptr(col_s),\ - colSize, rowSize).t(); - ++col_s; - } - else if (i % 5 == 3) - { - input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ - colSize - 1) = MatType(boundaries.colptr(col_b),\ - colSize, rowSize).t(); - ++col_b; - } - else - { - input_data.submat(row_idx, 0, row_idx + rowSize - 1,\ - colSize - 1) = MatType(Images.colptr(col_i), - colSize, rowSize).t(); - ++col_i; - } - row_idx += rowSize; - } - return input_data; -} - -*/ - /** - * Get DImensions of Features + * Get Dimensions of Features * @param FtrDim Output vector that contains the result */ template @@ -80,13 +37,7 @@ GetFeatureDimension(arma::vec& FtrDim) const size_t pSize = this->params.PSize(); const size_t numCell = this->params.NumCell(); const size_t numOrient = this->params.NumOrient(); - - size_t nColorCh; - if (this->params.RGBD() == 0) - nColorCh = 3; - else - nColorCh = 4; - + const size_t nColorCh = params.RGBD() == 0 ? 3 : 4; const size_t nCh = nColorCh + 2 * (1 + numOrient); FtrDim[0] = std::pow((pSize / shrink) , 2) * nCh; FtrDim[1] = std::pow(numCell , 2) * (std::pow (numCell, 2) - 1) / 2 * nCh; @@ -169,6 +120,10 @@ DistanceTransform2D(MatType &Im, const double inf) * @param Im Input binary Image whose distance transform is to be found. * @param on if on == 1, 1 is taken as boundaries and vice versa. * @param Out Output Image. + * This is the discription of the paper which discribes the approach + * for this algorithm : Distance Transforms of Sampled Functions, + * P. Felzenszwalb, D. Huttenlocher + * Theory of Computing, Vol. 8, No. 19, September 2012 */ template void StructuredForests:: @@ -252,7 +207,7 @@ RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ //see how to calculate this efficiently. numpy.dot does this. 
CubeType xyz(InImage.n_rows, InImage.n_cols, rgb2xyz.n_cols); - + /* for (size_t i = 0; i < InImage.slice(0).n_elem; ++i) { double r = InImage.slice(0)(i); @@ -263,6 +218,15 @@ RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ xyz.slice(1)(i) = 0.222015 * r + 0.706655 * g + 0.071330 * b; xyz.slice(2)(i) = 0.020183 * r + 0.129553 * g + 0.939180 * b; } + */ + + + xyz.slice(0) = 0.430574 * InImage.slice(0) + 0.341550 * InImage.slice(1)\ + + 0.178325 * InImage.slice(2); + xyz.slice(1) = 0.222015 * InImage.slice(0) + 0.706655 * InImage.slice(1)\ + + 0.071330 * InImage.slice(2); + xyz.slice(2) = 0.020183 * InImage.slice(0) + 0.129553 * InImage.slice(1)\ + + 0.939180 * InImage.slice(2); MatType nz(InImage.n_rows, InImage.n_cols); nz = 1.0 / ( xyz.slice(0) + (15 * xyz.slice(1) ) + @@ -293,7 +257,8 @@ RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ template void StructuredForests:: BilinearInterpolation(const MatType& src, - size_t height, size_t width, + const size_t height, + const size_t width, MatType& dst) { dst = MatType(height, width); @@ -330,15 +295,14 @@ BilinearInterpolation(const MatType& src, */ template void StructuredForests:: -SepFilter2D(CubeType &InOutImage, const arma::vec& kernel, const size_t radius) +Convolution(CubeType &InOutImage, const MatType& Filter, const size_t radius) { CubeType OutImage; this->CopyMakeBorder(InOutImage, radius, radius, radius, radius, OutImage); arma::vec row_res, col_res; // reverse InOutImage and OutImage to avoid making an extra matrix. - // InImage is renamed to InOutImage in this function for this reason only. - MatType k_mat = kernel * kernel.t(); + // InImage is renamed to InOutImage in this function for this reason only. 
for(size_t k = 0; k < OutImage.n_slices; ++k) { for(size_t j = radius; j < OutImage.n_cols - radius; ++j) @@ -348,7 +312,7 @@ SepFilter2D(CubeType &InOutImage, const arma::vec& kernel, const size_t radius) InOutImage(i - radius, j - radius, k) = arma::accu(OutImage.slice(k)\ .submat(i - radius, j - radius,\ - i + radius, j + radius) % k_mat); + i + radius, j + radius) % Filter); } } } @@ -373,8 +337,8 @@ ConvTriangle(CubeType &InImage, const size_t radius) const double p = 12.0 / radius / (radius + 2) - 2; arma::vec kernel = {1, p, 1}; kernel /= (p + 2); - - this->SepFilter2D(InImage, kernel, radius); + MatType Filter = kernel * kernel.t(); + this->Convolution(InImage, Filter, radius); } else { @@ -390,7 +354,8 @@ ConvTriangle(CubeType &InImage, const size_t radius) kernel(i) = r--; kernel /= std::pow(radius + 1, 2); - this->SepFilter2D(InImage, kernel, radius); + MatType Filter = kernel * kernel.t(); + this->Convolution(InImage, Filter, radius); } } @@ -408,7 +373,7 @@ MaxAndLoc(CubeType& mag, arma::umat& Location, MatType& MaxVal) const for(size_t j = 0; j < mag.n_cols; ++j) { /*can use -infinity here*/ - double max = std::numeric_limits::min(); + double max = -DBL_MAX; for(size_t k = 0; k < mag.n_slices; ++k) { if(mag(i, j, k) > max) @@ -423,7 +388,7 @@ MaxAndLoc(CubeType& mag, arma::umat& Location, MatType& MaxVal) const } /** - * Computes Gradient, Magnitude & Orientation. + * Computes Gradient, Magnitude & Orientation of the Edges. */ template void StructuredForests:: @@ -432,53 +397,22 @@ Gradient(const CubeType& InImage, MatType& Orientation) { const size_t grdNormRad = this->params.GrdNormRad(); - CubeType dx(InImage.n_rows, InImage.n_cols, InImage.n_slices), - dy(InImage.n_rows, InImage.n_cols, InImage.n_slices); - - dx.zeros(); - dy.zeros(); - - /* - From MATLAB documentation: - [FX,FY] = gradient(F), where F is a matrix, returns the - x and y components of the two-dImensional numerical gradient. 
- FX corresponds to ∂F/∂x, the differences in x (horizontal) direction. - FY corresponds to ∂F/∂y, the differences in the y (vertical) direction. - */ - - - /* - gradient calculates the central difference for interior data points. - For example, consider a matrix with unit-spaced data, A, that has - horizontal gradient G = gradient(A). The interior gradient values, G(:,j), are: - - G(:,j) = 0.5*(A(:,j+1) - A(:,j-1)); - where j varies between 2 and N-1, where N is size(A,2). - The gradient values along the edges of the matrix are calculated with single-sided differences, so that + // calculate gradients using sobel filter. + CubeType dx = InImage; + CubeType dy = InImage; - G(:,1) = A(:,2) - A(:,1); - G(:,N) = A(:,N) - A(:,N-1); - - The spacing between points in each direction is assumed to be one. - */ - for (size_t i = 0; i < InImage.n_slices; ++i) - { - dx.slice(i).col(0) = InImage.slice(i).col(1) - InImage.slice(i).col(0); - dx.slice(i).col(InImage.n_cols - 1) = InImage.slice(i).col(InImage.n_cols - 1) - - InImage.slice(i).col(InImage.n_cols - 2); - - for (size_t j = 1; j < InImage.n_cols-1; j++) - dx.slice(i).col(j) = 0.5 * ( InImage.slice(i).col(j+1) - InImage.slice(i).col(j) ); + MatType gx, gy; + gx << -1 << 0 << 1 << arma::endr + << -2 << 0 << 2 << arma::endr + << -1 << 0 << 1; - // do same for dy. 
- dy.slice(i).row(0) = InImage.slice(i).row(1) - InImage.slice(i).row(0); - dy.slice(i).row(InImage.n_rows - 1) = InImage.slice(i).row(InImage.n_rows - 1) - - InImage.slice(i).row(InImage.n_rows - 2); + gy << -1 << -2 << -1 << arma::endr + << 0 << 0 << 0 << arma::endr + << 1 << 2 << 1; - for (size_t j = 1; j < InImage.n_rows-1; j++) - dy.slice(i).row(j) = 0.5 * ( InImage.slice(i).row(j+1) - InImage.slice(i).row(j) ); - } + Convolution(dx, gx, 2); + Convolution(dy, gy, 2); CubeType mag(InImage.n_rows, InImage.n_cols, InImage.n_slices); for (size_t i = 0; i < InImage.n_slices; ++i) @@ -491,7 +425,7 @@ Gradient(const CubeType& InImage, this->MaxAndLoc(mag, Location, Magnitude); if(grdNormRad != 0) { - //we have to do this or override ConvTriangle and SepFilter2D methods. + //we have to do this or override ConvTriangle and Convolution methods. CubeType mag2(InImage.n_rows, InImage.n_cols, 1); mag2.slice(0) = Magnitude; this->ConvTriangle(mag2, grdNormRad); @@ -527,8 +461,9 @@ Gradient(const CubeType& InImage, template void StructuredForests:: Histogram(const MatType& Magnitude, - const MatType& Orientation, - size_t downscale, size_t interp, + const MatType& Orientation, + const size_t downscale, + const size_t interp, CubeType& HistArr) { @@ -568,9 +503,7 @@ Histogram(const MatType& Magnitude, } HistArr = HistArr / downscale; - - for (size_t i = 0; i < HistArr.n_slices; ++i) - HistArr.slice(i) = arma::square(HistArr.slice(i)); + HistArr = arma::square(HistArr); } /** @@ -615,7 +548,6 @@ GetShrunkChannels(const CubeType& InImage, CubeType& reg_ch,\ (luv.n_rows * scale), (luv.n_cols * scale), Img.slice(slice_idx)); } - this->ConvTriangle(Img, grdSmoothRad); MatType Magnitude(InImage.n_rows, InImage.n_cols), Orientation(InImage.n_rows, InImage.n_cols); @@ -646,11 +578,14 @@ GetShrunkChannels(const CubeType& InImage, CubeType& reg_ch,\ if (regSmoothRad > 1.0) this->ConvTriangle(channels, (size_t) (std::round(regSmoothRad)) ); + else this->ConvTriangle(channels, 
regSmoothRad); + if (ssSmoothRad > 1.0) this->ConvTriangle(channels, (size_t) (std::round(ssSmoothRad)) ); + else this->ConvTriangle(channels, ssSmoothRad); @@ -889,13 +824,9 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ Img = MatType(Images.colptr(col_i), colSize, rowSize * 3).t() / 255; col_i += 3; //Img = Images.submat((j * 3) * rowSize, 0, ((j * 3) + 3) * rowSize - 1, colSize - 1); - //bnds = Boundaries.submat( j * rowSize, 0, \ - j * rowSize - 1, colSize - 1 ); bnds = MatType(Boundaries.colptr(col_b), colSize, rowSize).t(); col_b++; - //segs = Segmentations.submat( j * rowSize, 0, \ - j * rowSize - 1, colSize - 1 ); segs = MatType(Segmentations.colptr(col_s), colSize, rowSize).t(); col_s++; @@ -1025,7 +956,7 @@ Discretize(const MatType& labels, const size_t nClass,\ else { //find most representative label (closest to mean) - arma::vec k = arma::sum(arma::abs(zs), 0); + arma::vec k = arma::sum(arma::abs(zs), 0).t(); ind = IndexMin(k); // so most representative label is: labels.row(ind). 
From 08bd5501d30f0d5a3b8dad2d2f438055e7b75cd5 Mon Sep 17 00:00:00 2001 From: nilayjain Date: Mon, 27 Jun 2016 04:46:12 +0000 Subject: [PATCH 11/14] timing tests, computing gradient by sobel filter --- src/mlpack/methods/CMakeLists.txt | 2 +- src/mlpack/tests/CMakeLists.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt index d0ba0ee3a66..adb67489b67 100644 --- a/src/mlpack/methods/CMakeLists.txt +++ b/src/mlpack/methods/CMakeLists.txt @@ -24,7 +24,7 @@ set(DIRS det emst edge_boxes -# fastmks + fastmks gmm hmm hoeffding_trees diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt index 95bc6949b7d..abf91d59de9 100644 --- a/src/mlpack/tests/CMakeLists.txt +++ b/src/mlpack/tests/CMakeLists.txt @@ -18,7 +18,7 @@ add_executable(mlpack_test distribution_test.cpp emst_test.cpp edge_boxes_test.cpp -# fastmks_test.cpp + fastmks_test.cpp feedforward_network_test.cpp gmm_test.cpp hmm_test.cpp @@ -64,7 +64,7 @@ add_executable(mlpack_test sgd_test.cpp serialization.hpp serialization.cpp -# serialization_test.cpp + serialization_test.cpp softmax_regression_test.cpp sort_policy_test.cpp sparse_autoencoder_test.cpp From 2f8f3c7778a2746fdbe22c9e8958467b014ab481 Mon Sep 17 00:00:00 2001 From: Jain Date: Mon, 11 Jul 2016 19:59:43 +0530 Subject: [PATCH 12/14] more comments need to be added --- .../methods/edge_boxes/feature_extraction.hpp | 94 +++++++++++++-- .../edge_boxes/feature_extraction_impl.hpp | 110 ++++-------------- .../methods/edge_boxes/feature_parameters.hpp | 38 +++++- 3 files changed, 144 insertions(+), 98 deletions(-) diff --git a/src/mlpack/methods/edge_boxes/feature_extraction.hpp b/src/mlpack/methods/edge_boxes/feature_extraction.hpp index 701d12a4fa9..38304f6722e 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction.hpp @@ -19,25 +19,64 @@ class StructuredForests public: - + /** + * 
Constructor: stores all the parameters in an object + * of feature_parameters class. + * @param F FeatureParameters object which stores necessary parameters. + */ StructuredForests(FeatureParameters F); void PrepareData(const MatType& Images, const MatType& Boundaries,\ const MatType& Segmentations); + /** + * Get Dimensions of Features + * @param FtrDim Output vector that contains the result dimensions. + */ void GetFeatureDimension(arma::vec& FtrDim); + /** + * Computes distance transform of 1D vector f. + * @param f input vector whose distance transform is to be found. + * @param n size of the Output vector to be made. + * @param inf a large double value. + * @param d Output vector which stores distance transform of f. + */ void DistanceTransform1D(const arma::vec& f, const size_t n,\ const double inf, arma::vec& d); + /** + * Computes distance transform of a 2D array + * @param Im input array whose distance transform is to be found. + * @param inf a large double value. + */ void DistanceTransform2D(MatType &Im, const double inf); - + + /** + * euclidean distance transform of binary Image using squared distance + * @param Im Input binary Image whose distance transform is to be found. + * @param on if on == 1, 1 is taken as boundaries and vice versa. + * @param Out Output Image. + * This is the discription of the paper which discribes the approach + * for this algorithm : Distance Transforms of Sampled Functions, + * P. Felzenszwalb, D. Huttenlocher + * Theory of Computing, Vol. 8, No. 19, September 2012 + */ void DistanceTransformImage(const MatType& Im, double on, MatType& Out); void GetFeatures(const MatType &Image, arma::umat &loc,\ CubeType& RegFtr, CubeType& SSFtr,\ const arma::vec& table); - + + /** + * Makes a reflective border around an Image. + * @param InImage, Image which we have to make border around. + * @param top, border length (to be incremented) at top. + * @param left, border length at left. + * @param bottom, border length at bottom. 
+ * @param right, border length at right. + * @param OutImage, Output Image. + */ void CopyMakeBorder(const CubeType& InImage, size_t top, size_t left, size_t bottom, size_t right, CubeType& OutImage); @@ -45,27 +84,68 @@ class StructuredForests void GetShrunkChannels(const CubeType& InImage, CubeType ®_ch,\ CubeType &ss_ch, const arma::vec& table); + /** + * Converts an Image in RGB color space to LUV color space. + * RGB must range in (0.0, 1.0). + * @param InImage Input Image in RGB color space. + * @param OutImage Ouptut Image in LUV color space. + */ void RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ const arma::vec& table); + /** + * Resizes the Image to the given size using Bilinear Interpolation + * @param src Input Image + * @param height Height of Output Image. + * @param width Width Out Output Image. + * @param dst Output Image resized to (height, width) + */ void BilinearInterpolation(const MatType& src, const size_t height, const size_t width, MatType& dst); + /** + * Applies a separable linear filter to an Image + * @param InOutImage Input/Output Contains the input Image, The final filtered Image is + * stored in this param. + * @param kernel Input Kernel vector to be applied on Image. + * @param radius amount, the Image should be padded before applying filter. + */ void Convolution(CubeType &InOutImage, const MatType& Filter, const size_t radius); + /** + * Applies a triangle filter on an Image. + * @param InImage Input/Output Image on which filter is applied. + * @param radius Decides the size of kernel to be applied on Image. 
+ */ void ConvTriangle(CubeType &InImage, const size_t radius); - void ConvTriangle2(CubeType& InImage, const size_t radius, CubeType& Output); - + /** + * finds maximum of numbers on cube axis and stores maximum values + * in MaxVal, locations of maximum values in Location + * @param mag Input Cube for which we want to find max values and location + * @param Location Stores the slice number at which max value occurs + * @param MaxVal Stores the maximum value among all slices for a given (row, col). + */ + void MaxAndLoc(CubeType &mag, arma::umat &Location, MatType& MaxVal) const; + /** + * Computes Magnitude & Orientation of the Edges. + * Gradient of a function is a vector of partial derivatives in each direction. + * In this function the edges are calculated by applying the sobel filter on Image + * which is the same as finding the vectors of partial derivates. + * These "vectors" have a magnitude and a direction (orientation), which we + * calculate in this function. + * @param InImage Input Image for which we calculate Magnitude & Orientation. + * @param Magnitude Magnitude of the Edges + * @param Orientation Orientation of the Edges + */ void Gradient(const CubeType& InImage, MatType& Magnitude, MatType& Orientation); - void MaxAndLoc(CubeType &mag, arma::umat &Location, MatType& MaxVal) const; - + void Histogram(const MatType& Magnitude, const MatType& Orientation, const size_t downscale, diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp index f70adc72516..5edbcddc6e0 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ -15,18 +15,12 @@ using namespace mlpack::ann; namespace mlpack { namespace structured_tree { -/** - * Constructor: stores all the parameters in an object - * of feature_parameters class. 
- */ + template StructuredForests:: StructuredForests(FeatureParameters F) : params(F) {} -/** - * Get Dimensions of Features - * @param FtrDim Output vector that contains the result - */ + template void StructuredForests:: GetFeatureDimension(arma::vec& FtrDim) @@ -43,13 +37,7 @@ GetFeatureDimension(arma::vec& FtrDim) FtrDim[1] = std::pow(numCell , 2) * (std::pow (numCell, 2) - 1) / 2 * nCh; } -/** - * Computes distance transform of 1D vector f. - * @param f input vector whose distance transform is to be found. - * @param n size of the Output vector to be made. - * @param inf a large double value. - * @param d Output vector which stores distance transform of f. - */ + template void StructuredForests:: DistanceTransform1D(const arma::vec& f, const size_t n, const double inf, @@ -85,11 +73,6 @@ DistanceTransform1D(const arma::vec& f, const size_t n, const double inf, } } -/** - * Computes distance transform of a 2D array - * @param Im input array whose distance transform is to be found. - * @param inf a large double value. - */ template void StructuredForests:: @@ -115,16 +98,7 @@ DistanceTransform2D(MatType &Im, const double inf) } } -/** - * euclidean distance transform of binary Image using squared distance - * @param Im Input binary Image whose distance transform is to be found. - * @param on if on == 1, 1 is taken as boundaries and vice versa. - * @param Out Output Image. - * This is the discription of the paper which discribes the approach - * for this algorithm : Distance Transforms of Sampled Functions, - * P. Felzenszwalb, D. Huttenlocher - * Theory of Computing, Vol. 8, No. 19, September 2012 - */ + template void StructuredForests:: DistanceTransformImage(const MatType& Im, double on, MatType& Out) @@ -136,15 +110,7 @@ DistanceTransformImage(const MatType& Im, double on, MatType& Out) this->DistanceTransform2D(Out, inf); } -/** - * Makes a reflective border around an Image. - * @param InImage, Image which we have to make border around. 
- * @param top, border length (to be incremented) at top. - * @param left, border length at left. - * @param bottom, border length at bottom. - * @param right, border length at right. - * @param OutImage, Output Image. - */ + template void StructuredForests:: CopyMakeBorder(const CubeType& InImage, size_t top, @@ -158,6 +124,7 @@ CopyMakeBorder(const CubeType& InImage, size_t top, OutImage.slice(i).submat(top, left, InImage.n_rows + top - 1, InImage.n_cols + left - 1) = InImage.slice(i); + // first copy borders from left and right for(size_t j = 0; j < right; ++j) { OutImage.slice(i).col(InImage.n_cols + left + j).subvec(top, InImage.n_rows + top - 1) @@ -170,6 +137,7 @@ CopyMakeBorder(const CubeType& InImage, size_t top, = InImage.slice(i).col(left - 1 - j); } + // copy borders from top and bottom for(size_t j = 0; j < top; j++) { @@ -186,41 +154,19 @@ CopyMakeBorder(const CubeType& InImage, size_t top, } } -/** - * Converts an Image in RGB color space to LUV color space. - * RGB must range in (0.0, 1.0). - * @param InImage Input Image in RGB color space. - * @param OutImage Ouptut Image in LUV color space. - */ + template void StructuredForests:: RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ const arma::vec& table) { - //assert type is double or float. - - MatType rgb2xyz; rgb2xyz << 0.430574 << 0.222015 << 0.020183 << arma::endr << 0.341550 << 0.706655 << 0.129553 << arma::endr << 0.178325 << 0.071330 << 0.939180; - //see how to calculate this efficiently. numpy.dot does this. 
CubeType xyz(InImage.n_rows, InImage.n_cols, rgb2xyz.n_cols); - /* - for (size_t i = 0; i < InImage.slice(0).n_elem; ++i) - { - double r = InImage.slice(0)(i); - double g = InImage.slice(1)(i); - double b = InImage.slice(2)(i); - - xyz.slice(0)(i) = 0.430574 * r + 0.341550 * g + 0.178325 * b; - xyz.slice(1)(i) = 0.222015 * r + 0.706655 * g + 0.071330 * b; - xyz.slice(2)(i) = 0.020183 * r + 0.129553 * g + 0.939180 * b; - } - */ - - + xyz.slice(0) = 0.430574 * InImage.slice(0) + 0.341550 * InImage.slice(1)\ + 0.178325 * InImage.slice(2); xyz.slice(1) = 0.222015 * InImage.slice(0) + 0.706655 * InImage.slice(1)\ @@ -246,13 +192,7 @@ RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ - 13 * 0.468331) + 134 * maxi; } -/** - * Resizes the Image to the given size using Bilinear Interpolation - * @param src Input Image - * @param height Height of Output Image. - * @param width Width Out Output Image. - * @param dst Output Image resized to (height, width) - */ + /*Implement this function in a column major order.*/ template void StructuredForests:: @@ -286,13 +226,7 @@ BilinearInterpolation(const MatType& src, } } -/** - * Applies a separable linear filter to an Image - * @param InOutImage Input/Output Contains the input Image, The final filtered Image is - * stored in this param. - * @param kernel Input Kernel vector to be applied on Image. - * @param radius amount, the Image should be padded before applying filter. - */ + template void StructuredForests:: Convolution(CubeType &InOutImage, const MatType& Filter, const size_t radius) @@ -319,18 +253,14 @@ Convolution(CubeType &InOutImage, const MatType& Filter, const size_t radius) } -/** - * Applies a triangle filter on an Image. - * @param InImage Input/Output Image on which filter is applied. - * @param radius Decides the size of kernel to be applied on Image. 
- */ + template void StructuredForests:: ConvTriangle(CubeType &InImage, const size_t radius) { if (radius == 0) { - //nothing to do + // nothing to do } else if (radius <= 1) { @@ -359,9 +289,7 @@ ConvTriangle(CubeType &InImage, const size_t radius) } } -//just a helper function, can't use it for anything else -//finds max numbers on cube axis and returns max values, -// also stores the locations of max values in Location + template void StructuredForests:: MaxAndLoc(CubeType& mag, arma::umat& Location, MatType& MaxVal) const @@ -387,9 +315,7 @@ MaxAndLoc(CubeType& mag, arma::umat& Location, MatType& MaxVal) const } } -/** - * Computes Gradient, Magnitude & Orientation of the Edges. - */ + template void StructuredForests:: Gradient(const CubeType& InImage, @@ -403,6 +329,8 @@ Gradient(const CubeType& InImage, CubeType dy = InImage; MatType gx, gy; + + // values for sobel filter. gx << -1 << 0 << 1 << arma::endr << -2 << 0 << 2 << arma::endr << -1 << 0 << 1; @@ -414,6 +342,7 @@ Gradient(const CubeType& InImage, Convolution(dx, gx, 2); Convolution(dy, gy, 2); + // calculate the magnitudes of edges. CubeType mag(InImage.n_rows, InImage.n_cols, InImage.n_slices); for (size_t i = 0; i < InImage.n_slices; ++i) { @@ -431,6 +360,7 @@ Gradient(const CubeType& InImage, this->ConvTriangle(mag2, grdNormRad); Magnitude = Magnitude / (mag2.slice(0) + 0.01); } + MatType dx_mat(dx.n_rows, dx.n_cols),\ dy_mat(dy.n_rows, dy.n_cols); @@ -442,6 +372,8 @@ Gradient(const CubeType& InImage, dy_mat(i, j) = dy(i, j, Location(i, j)); } } + + // calculate Orientation of edges. 
Orientation = arma::atan(dy_mat / dx_mat); Orientation.transform( [](double val)\ diff --git a/src/mlpack/methods/edge_boxes/feature_parameters.hpp b/src/mlpack/methods/edge_boxes/feature_parameters.hpp index 895ab878adf..28b57ba97f9 100644 --- a/src/mlpack/methods/edge_boxes/feature_parameters.hpp +++ b/src/mlpack/methods/edge_boxes/feature_parameters.hpp @@ -1,5 +1,5 @@ /** - * @file feature_extraction_impl.hpp + * @file feature_parameters.hpp * @author Nilay Jain * * Implementation of feature parameter class. @@ -11,13 +11,14 @@ namespace mlpack { namespace structured_tree { -//This class holds all the fields for the FeatureExtraction class. +//! This class holds all the fields for the FeatureExtraction class. class FeatureParameters { public: FeatureParameters(){} //default constructor + //! getter and setter methods for all the fields in class. void NumImages(size_t value) { numImages = value; } size_t NumImages() const { return numImages; } @@ -70,22 +71,55 @@ class FeatureParameters double NumTree() const { return numTree; } private: + //! number of images in the dataset. size_t numImages; + + //! row size of images. size_t rowSize; + + //! column size of images. size_t colSize; + + //! 0 for RGB, 1 for RGB + depth. size_t rgbd; + + //! amount to shrink channels size_t shrink; + + //! number of orientations per gradient scale size_t numOrient; + + //! radius for image gradient smoothing size_t grdSmoothRad; + + //! radius for gradient normalization size_t grdNormRad; + + //! radius for regular channel smoothing size_t regSmoothRad; + + //! radius for similar channel smooothing size_t ssSmoothRad; + + //! fraction of features to use to train each tree double fraction; + + //! size of image patches size_t pSize; + + //! size of ground truth patches size_t gSize; + + //! number of positive patches per tree size_t numPos; + + //! number of negative patches per tree size_t numNeg; + + //! number of self similarity cells size_t numCell; + + //! 
number of trees in forest to train size_t numTree; }; From e9987d97a5b7d599c748a33170564f611ded0a60 Mon Sep 17 00:00:00 2001 From: Jain Date: Tue, 12 Jul 2016 18:55:54 +0530 Subject: [PATCH 13/14] comments --- .../methods/edge_boxes/feature_extraction.hpp | 91 +++++++++++++++++-- 1 file changed, 85 insertions(+), 6 deletions(-) diff --git a/src/mlpack/methods/edge_boxes/feature_extraction.hpp b/src/mlpack/methods/edge_boxes/feature_extraction.hpp index 38304f6722e..96f3a19e70a 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction.hpp @@ -54,16 +54,24 @@ class StructuredForests /** * euclidean distance transform of binary Image using squared distance - * @param Im Input binary Image whose distance transform is to be found. - * @param on if on == 1, 1 is taken as boundaries and vice versa. - * @param Out Output Image. * This is the discription of the paper which discribes the approach * for this algorithm : Distance Transforms of Sampled Functions, * P. Felzenszwalb, D. Huttenlocher * Theory of Computing, Vol. 8, No. 19, September 2012 + * @param Im Input binary Image whose distance transform is to be found. + * @param on if on == 1, 1 is taken as boundaries and vice versa. + * @param Out Output Image. */ void DistanceTransformImage(const MatType& Im, double on, MatType& Out); + /** + * Compute the Regular and Self Similarity Features of the given image. + * @param Image Given Input Image + * @param loc Locations at which features need to be extracted + * @param RegFtr Output the Regular Features + * @param SSFtr Output the Self Similarity Features + * @param table a helper vector required to convert image to LAB space. 
+ */ void GetFeatures(const MatType &Image, arma::umat &loc,\ CubeType& RegFtr, CubeType& SSFtr,\ const arma::vec& table); @@ -81,12 +89,25 @@ class StructuredForests size_t left, size_t bottom, size_t right, CubeType& OutImage); + /** + * Augment image patch with multiple channels of information + * resulting in a feature vector. Shrink these channels by shrink + * size to reduce dimensions of the extracted candidate features. + * Refer to the following paper for details: + * Fast edge detection using structured forests + * Authors: Piotr Dollar, Larry Zitnick + * Published In: ICCV + * @param InImage Input Image. + * @param regCh Channels used in calculating Regular Features. + * @param ssCh Channels used in calculating Self Similarity Features. + * @param table a vector which helps in converting image from RGB to LUV. + */ void GetShrunkChannels(const CubeType& InImage, CubeType ®_ch,\ CubeType &ss_ch, const arma::vec& table); /** * Converts an Image in RGB color space to LUV color space. - * RGB must range in (0.0, 1.0). + * RGB values must be normalized i.e., range in [0.0, 1.0]. * @param InImage Input Image in RGB color space. * @param OutImage Ouptut Image in LUV color space. */ @@ -145,27 +166,85 @@ class StructuredForests MatType& Magnitude, MatType& Orientation); - + /** + * Compute Histogram of Oriented Gradients ( count occurrences of gradient + * orientation in localized portions of an image. ) + * @param Magnitude gradient magnitude + * @param Orientation gradient orientation + * @param downscale spatially downscaling factor + * @param interp: 1 (true) for interpolation over orientations + */ void Histogram(const MatType& Magnitude, const MatType& Orientation, const size_t downscale, const size_t interp, CubeType& HistArr); + /** + * store the features as a window view for a 16 x 16 patch, + * centered at each location in loc. 
+ * @param channels Input candidate features + * @param loc Input locations + * @param features Output features at each 16 x 16 window centered at loc + */ void ViewAsWindows(const CubeType& channels, const arma::umat& loc, CubeType& features); + /** + * Regular Features are features directly indexing to the image locations. + * Each channel is of same size as that of image and captures a different + * faucet of information, like color, gradient and oriented gradient + * information at a patch x. Refer to the below paper for details: + * Sketch Tokens: A Learned Mid-level Representation for Contour + * and Object Detection. + * Published in: Computer Vision and Pattern Recognition (CVPR), + * 2013 IEEE Conference + * Author(s) : Joseph J. Lim ; C. Lawrence Zitnick ; Piotr Dollár + * @param channels input candidate regular features + * @param loc input locations + * @param RegFtr output Regular Features at that location. + */ void GetRegFtr(const CubeType& channels, const arma::umat& loc, CubeType& RegFtr); + /** + * The self-similarity features capture the portions + * of an image patch that contain similar textures based + * on color or gradient information. + * Refer the sketch tokens paper for details. + * @param channels input candidate self similarity features + * @param loc input locations + * @param RegFtr output Self Similarity Features at that location. + */ void GetSSFtr(const CubeType& channels, const arma::umat& loc, - CubeType SSFtr); + CubeType& SSFtr); + /** + * Rearranges the features so that we can use them + * in a fast and efficient way. + * convert (16, 16, 13 * 1000) features to (256, 1000, 13). + * @param channels input features + * @param ch output features just rearranged + */ void Rearrange(const CubeType& channels, CubeType& ch); + /** + * Compute the pairwise differences between n-dimensional points in a way + * specified in the paper "Structured Forests for Fast Edge Detection". 
+ * Note: Indeed this is not a valid distance measurement (asymmetry). + * find nC2 differences, for locations in the grid_pos. + * @param features n-dimensional points + * @param grid_pos locations + * @param Output stores differences between features at grid_pos locations + */ void PDist(const CubeType& features, const arma::uvec& grid_pos, CubeType& Output); + /** + * Find the index of minimum element in a vector + * @param k Vector in which we want to find element + * return value index of minimum element + */ size_t IndexMin(arma::vec& k); size_t Discretize(const MatType& labels, const size_t nClass,\ From 72eb1ef22c7db1ea33af3de1cd043cdb277ec562 Mon Sep 17 00:00:00 2001 From: nilayjain Date: Sun, 14 Aug 2016 06:20:29 +0000 Subject: [PATCH 14/14] incorporated fixes, added comments in PrepareData --- src/mlpack/methods/edge_boxes/CMakeLists.txt | 1 + .../methods/edge_boxes/edge_boxes_main.cpp | 8 +- .../methods/edge_boxes/feature_extraction.hpp | 15 +-- .../edge_boxes/feature_extraction_impl.hpp | 119 +++++++++--------- .../methods/edge_boxes/feature_parameters.hpp | 11 +- 5 files changed, 77 insertions(+), 77 deletions(-) diff --git a/src/mlpack/methods/edge_boxes/CMakeLists.txt b/src/mlpack/methods/edge_boxes/CMakeLists.txt index ce7bdea79ab..72345bd15e2 100644 --- a/src/mlpack/methods/edge_boxes/CMakeLists.txt +++ b/src/mlpack/methods/edge_boxes/CMakeLists.txt @@ -4,6 +4,7 @@ cmake_minimum_required(VERSION 2.8) # Define the files we need to compile. # Anything not in this list will not be compiled into mlpack. 
set(SOURCES + feature_parameters.hpp feature_extraction.hpp feature_extraction_impl.hpp ) diff --git a/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp b/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp index 9d9fcb25265..464f2240cc8 100644 --- a/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp +++ b/src/mlpack/methods/edge_boxes/edge_boxes_main.cpp @@ -9,7 +9,6 @@ using namespace mlpack; using namespace mlpack::structured_tree; -using namespace std; int main(int argc, char** argv) { @@ -45,7 +44,7 @@ int main(int argc, char** argv) nms: if true apply non-maximum suppression to edges */ - FeatureParameters params = FeatureParameters(); + FeatureParameters params; params.NumImages(2); params.RowSize(321); params.ColSize(481); @@ -65,8 +64,6 @@ int main(int argc, char** argv) params.NumCell(5); params.NumTree(8); StructuredForests SF(params); -// arma::uvec x(2); - //SF.GetFeatureDimension(x); arma::mat segmentations, boundaries, images; data::Load("/home/nilay/example/small_images.csv", images); @@ -74,9 +71,10 @@ int main(int argc, char** argv) data::Load("/home/nilay/example/small_segmentation_1.csv", segmentations); SF.PrepareData(images, boundaries, segmentations); - cout << "PrepareData done." << endl; + std::cout << "PrepareData done." << std::endl; return 0; } + diff --git a/src/mlpack/methods/edge_boxes/feature_extraction.hpp b/src/mlpack/methods/edge_boxes/feature_extraction.hpp index 96f3a19e70a..39a7579737b 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction.hpp @@ -26,7 +26,7 @@ class StructuredForests */ StructuredForests(FeatureParameters F); - void PrepareData(const MatType& Images, const MatType& Boundaries,\ + void PrepareData(const MatType& Images, const MatType& Boundaries, const MatType& Segmentations); /** @@ -42,7 +42,7 @@ class StructuredForests * @param inf a large double value. * @param d Output vector which stores distance transform of f. 
*/ - void DistanceTransform1D(const arma::vec& f, const size_t n,\ + void DistanceTransform1D(const arma::vec& f, const size_t n, const double inf, arma::vec& d); /** @@ -72,8 +72,8 @@ class StructuredForests * @param SSFtr Output the Self Similarity Features * @param table a helper vector required to convert image to LAB space. */ - void GetFeatures(const MatType &Image, arma::umat &loc,\ - CubeType& RegFtr, CubeType& SSFtr,\ + void GetFeatures(const MatType &Image, arma::umat &loc, + CubeType& RegFtr, CubeType& SSFtr, const arma::vec& table); /** @@ -102,7 +102,7 @@ class StructuredForests * @param ssCh Channels used in calculating Self Similarity Features. * @param table a vector which helps in converting image from RGB to LUV. */ - void GetShrunkChannels(const CubeType& InImage, CubeType ®_ch,\ + void GetShrunkChannels(const CubeType& InImage, CubeType ®_ch, CubeType &ss_ch, const arma::vec& table); /** @@ -111,7 +111,7 @@ class StructuredForests * @param InImage Input Image in RGB color space. * @param OutImage Ouptut Image in LUV color space. */ - void RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ + void RGB2LUV(const CubeType& InImage, CubeType& OutImage, const arma::vec& table); /** @@ -247,7 +247,7 @@ class StructuredForests */ size_t IndexMin(arma::vec& k); - size_t Discretize(const MatType& labels, const size_t nClass,\ + size_t Discretize(const MatType& labels, const size_t nClass, const size_t nSample, arma::vec& DiscreteLabels); }; @@ -259,3 +259,4 @@ class StructuredForests + diff --git a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp index 5edbcddc6e0..4610910066b 100644 --- a/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp +++ b/src/mlpack/methods/edge_boxes/feature_extraction_impl.hpp @@ -4,21 +4,20 @@ * * Implementation of feature extraction methods. 
*/ -#ifndef MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_ImPL_HPP -#define MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_ImPL_HPP +#ifndef MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_IMPL_HPP +#define MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_IMPL_HPP #include "feature_extraction.hpp" #include -using namespace mlpack::ann; namespace mlpack { namespace structured_tree { template StructuredForests:: -StructuredForests(FeatureParameters F) : params(F) {} +StructuredForests(mlpack::structured_tree::FeatureParameters F) : params(F) {} template @@ -157,21 +156,17 @@ CopyMakeBorder(const CubeType& InImage, size_t top, template void StructuredForests:: -RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ +RGB2LUV(const CubeType& InImage, CubeType& OutImage, const arma::vec& table) { - MatType rgb2xyz; - rgb2xyz << 0.430574 << 0.222015 << 0.020183 << arma::endr - << 0.341550 << 0.706655 << 0.129553 << arma::endr - << 0.178325 << 0.071330 << 0.939180; - CubeType xyz(InImage.n_rows, InImage.n_cols, rgb2xyz.n_cols); + CubeType xyz(InImage.n_rows, InImage.n_cols, 3); - xyz.slice(0) = 0.430574 * InImage.slice(0) + 0.341550 * InImage.slice(1)\ + xyz.slice(0) = 0.430574 * InImage.slice(0) + 0.341550 * InImage.slice(1) + 0.178325 * InImage.slice(2); - xyz.slice(1) = 0.222015 * InImage.slice(0) + 0.706655 * InImage.slice(1)\ + xyz.slice(1) = 0.222015 * InImage.slice(0) + 0.706655 * InImage.slice(1) + 0.071330 * InImage.slice(2); - xyz.slice(2) = 0.020183 * InImage.slice(0) + 0.129553 * InImage.slice(1)\ + xyz.slice(2) = 0.020183 * InImage.slice(0) + 0.129553 * InImage.slice(1) + 0.939180 * InImage.slice(2); MatType nz(InImage.n_rows, InImage.n_cols); @@ -186,9 +181,9 @@ RGB2LUV(const CubeType& InImage, CubeType& OutImage,\ } } double maxi = 1.0 / 270.0; - OutImage.slice(1) = OutImage.slice(0) % (13 * 4 * (xyz.slice(0) % nz) \ + OutImage.slice(1) = OutImage.slice(0) % (13 * 4 * (xyz.slice(0) % nz) - 13 * 0.197833) + 88 * maxi; - OutImage.slice(2) = OutImage.slice(0) % (13 * 9 * 
(xyz.slice(1) % nz) \ + OutImage.slice(2) = OutImage.slice(0) % (13 * 9 * (xyz.slice(1) % nz) - 13 * 0.468331) + 134 * maxi; } @@ -244,8 +239,8 @@ Convolution(CubeType &InOutImage, const MatType& Filter, const size_t radius) for(size_t i = radius; i < OutImage.n_rows - radius; ++i) { InOutImage(i - radius, j - radius, k) = - arma::accu(OutImage.slice(k)\ - .submat(i - radius, j - radius,\ + arma::accu(OutImage.slice(k) + .submat(i - radius, j - radius, i + radius, j + radius) % Filter); } } @@ -260,7 +255,7 @@ ConvTriangle(CubeType &InImage, const size_t radius) { if (radius == 0) { - // nothing to do + return; } else if (radius <= 1) { @@ -300,7 +295,6 @@ MaxAndLoc(CubeType& mag, arma::umat& Location, MatType& MaxVal) const { for(size_t j = 0; j < mag.n_cols; ++j) { - /*can use -infinity here*/ double max = -DBL_MAX; for(size_t k = 0; k < mag.n_slices; ++k) { @@ -346,7 +340,7 @@ Gradient(const CubeType& InImage, CubeType mag(InImage.n_rows, InImage.n_cols, InImage.n_slices); for (size_t i = 0; i < InImage.n_slices; ++i) { - mag.slice(i) = arma::sqrt( arma::square \ + mag.slice(i) = arma::sqrt( arma::square ( dx.slice(i) + arma::square( dy.slice(i) ) ) ); } @@ -361,7 +355,7 @@ Gradient(const CubeType& InImage, Magnitude = Magnitude / (mag2.slice(0) + 0.01); } - MatType dx_mat(dx.n_rows, dx.n_cols),\ + MatType dx_mat(dx.n_rows, dx.n_cols), dy_mat(dy.n_rows, dy.n_cols); for(size_t j = 0; j < InImage.n_cols; ++j) @@ -376,8 +370,8 @@ Gradient(const CubeType& InImage, // calculate Orientation of edges. 
Orientation = arma::atan(dy_mat / dx_mat); - Orientation.transform( [](double val)\ - { if(val < 0) return (val + arma::datum::pi);\ + Orientation.transform( [](double val) + { if(val < 0) return (val + arma::datum::pi); else return (val);} ); for(size_t j = 0; j < InImage.n_cols; ++j) @@ -449,7 +443,7 @@ Histogram(const MatType& Magnitude, template void StructuredForests:: -GetShrunkChannels(const CubeType& InImage, CubeType& reg_ch,\ +GetShrunkChannels(const CubeType& InImage, CubeType& reg_ch, CubeType& ss_ch, const arma::vec& table) { CubeType luv; @@ -493,7 +487,7 @@ GetShrunkChannels(const CubeType& InImage, CubeType& reg_ch,\ BilinearInterpolation( Magnitude, rsize, csize, channels.slice(slice_idx)); slice_idx++; for(size_t i = 0; i < InImage.n_slices; ++i) - BilinearInterpolation( Magnitude, rsize, csize,\ + BilinearInterpolation( Magnitude, rsize, csize, channels.slice(i + slice_idx)); slice_idx += 3; scale -= 0.5; @@ -542,7 +536,7 @@ ViewAsWindows(const CubeType& channels, const arma::umat& loc, size_t y = loc(i, 1); /*(x,y) in channels, is ((x+p), (y+p)) in incCh*/ - CubeType patch = incCh.tube((x + p) - p, (y + p) - p,\ + CubeType patch = incCh.tube((x + p) - p, (y + p) - p, (x + p) + p - 1, (y + p) + p - 1); // since each patch has 13 channel we have to increase the index by 13 features.slices(channel, channel + 12) = patch; @@ -607,11 +601,11 @@ PDist(const CubeType& features, const arma::uvec& grid_pos, } } -//returns 300,1000,13 dImension features. +//returns (300, 1000, 13) dimension features. 
template void StructuredForests:: GetSSFtr(const CubeType& channels, const arma::umat& loc, - CubeType SSFtr) + CubeType& SSFtr) { const size_t shrink = this->params.Shrink(); const size_t pSize = this->params.PSize() / shrink; @@ -623,7 +617,7 @@ GetSSFtr(const CubeType& channels, const arma::umat& loc, arma::uvec g_pos(numCell); for(size_t i = 0; i < numCell; ++i) { - g_pos(i) = (size_t)round( (i + 1) * (pSize + 2 * half_cell_size \ + g_pos(i) = (size_t)round( (i + 1) * (pSize + 2 * half_cell_size - 1) / (numCell + 1.0) - half_cell_size); } arma::uvec grid_pos(numCell * numCell); @@ -653,13 +647,12 @@ GetFeatures(const MatType &Image, arma::umat &loc, const size_t colSize = this->params.ColSize(); const size_t bottom = (4 - (Image.n_rows / 3) % 4) % 4; const size_t right = (4 - Image.n_cols % 4) % 4; - //cout << "Botttom = " << bottom << " right = " << right << endl; CubeType InImage(Image.n_rows / 3, Image.n_cols, 3); for(size_t i = 0; i < 3; ++i) { - InImage.slice(i) = Image.submat(i * rowSize, 0, \ + InImage.slice(i) = Image.submat(i * rowSize, 0, (i + 1) * rowSize - 1, colSize - 1); } @@ -672,14 +665,11 @@ GetFeatures(const MatType &Image, arma::umat &loc, const size_t rsize = OutImage.n_rows / shrink; const size_t csize = OutImage.n_cols / shrink; - /* this part gives double free or corruption Out error - when executed for a second tIme */ CubeType reg_ch = CubeType(rsize, csize, numChannels); CubeType ss_ch = CubeType(rsize, csize, numChannels); this->GetShrunkChannels(InImage, reg_ch, ss_ch, table); - loc /= shrink; this->GetRegFtr(reg_ch, loc, RegFtr); this->GetSSFtr(ss_ch, loc, SSFtr); @@ -693,9 +683,10 @@ GetFeatures(const MatType &Image, arma::umat &loc, template void StructuredForests:: -PrepareData(const MatType& Images, const MatType& Boundaries,\ +PrepareData(const MatType& Images, const MatType& Boundaries, const MatType& Segmentations) { + // create temporary variables. 
const size_t numImages = this->params.NumImages(); const size_t numTree = this->params.NumTree(); const size_t numPos = this->params.NumPos(); @@ -709,25 +700,26 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ // gRad = radius of ground truth patches. const size_t pRad = pSize / 2, gRad = gSize / 2; arma::vec FtrDim; + // get the dimensions of the features. this->GetFeatureDimension(FtrDim); const size_t nFtrDim = FtrDim(0) + FtrDim(1); + // we only keep a fraction of features. const size_t nSmpFtrDim = (size_t)(nFtrDim * fraction); - size_t time=0; for(size_t i = 0; i < numTree; ++i) { - //Implement the logic for if data already exists. + // Implement the logic for if data already exists. + // this is our new feature dimension. MatType ftrs = arma::zeros(numPos + numNeg, nSmpFtrDim); - //effectively a 3d array. . . + // effectively a 3d array. MatType lbls = arma::zeros((numPos + numNeg ), gSize * gSize); // still to be done: store features and labels calculated // in the loop and store it in these Matrices. - // Could use some suggestions for this. - + size_t loop_iter = numImages; - // a vector which helps in converting Image from RGB2LUV. + // table is a vector which helps in converting Image from RGB2LUV. double a, y0, maxi; a = std::pow(29.0, 3) / 27.0; y0 = 8.0 / a; @@ -750,18 +742,22 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ table(i) = table(i - 1); size_t col_i = 0, col_s = 0, col_b = 0; + // process data of each image one by one. for(size_t j = 0; j < loop_iter; ++j) { + // these variables store image, boundaries and segmentation information + // for each image in our dataset.
MatType Img, bnds, segs; + Img = MatType(Images.colptr(col_i), colSize, rowSize * 3).t() / 255; col_i += 3; - //Img = Images.submat((j * 3) * rowSize, 0, ((j * 3) + 3) * rowSize - 1, colSize - 1); bnds = MatType(Boundaries.colptr(col_b), colSize, rowSize).t(); col_b++; segs = MatType(Segmentations.colptr(col_s), colSize, rowSize).t(); col_s++; + MatType mask(rowSize, colSize, arma::fill::ones); mask.col(pRad - 1).fill(0); mask.row( (mask.n_rows - 1) - (pRad - 1) ).fill(0); @@ -770,35 +766,38 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ mask.n_cols - 1).fill(0); // number of positive or negative patches per ground truth. - const size_t nPatchesPerGt = 500; + MatType dis; + + // calculate distance transform of image boundary. this->DistanceTransformImage(bnds, 1, dis); + // take square root for euclidean distance transform. dis = arma::sqrt(dis); + + // find positive and negative edge locations using mask. arma::uvec posLoc = arma::find( (dis < gRad) % mask ); arma::uvec negLoc = arma::find( (dis >= gRad) % mask ); + // we take a random permutation of posLoc and negLoc. posLoc = arma::shuffle(posLoc); negLoc = arma::shuffle(negLoc); - size_t lenLoc = std::min((int) negLoc.n_elem, std::min((int) nPatchesPerGt,\ + size_t lenLoc = std::min((int) negLoc.n_elem, std::min((int) nPatchesPerGt, (int) posLoc.n_elem)); arma::umat loc(lenLoc * 2, 2); for(size_t i = 0; i < lenLoc; ++i) - { loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), posLoc(i) ).t(); - //cout << "posLoc: " << loc(i, 0) << ", " << loc(i, 1) << endl; - } - + for(size_t i = lenLoc; i < 2 * lenLoc; ++i) - { loc.row(i) = arma::ind2sub(arma::size(dis.n_rows, dis.n_cols), negLoc(i - lenLoc) ).t(); - } CubeType SSFtr, RegFtr; Timer::Start("get_features"); + // calculate the regular and self similarity features of + // the image Img at locations loc. 
this->GetFeatures(Img, loc, RegFtr, SSFtr, table); Timer::Stop("get_features"); @@ -816,18 +815,20 @@ PrepareData(const MatType& Images, const MatType& Boundaries,\ // have to do this or we can overload the CopyMakeBorder to support MatType. s.slice(0) = segs; CubeType in_segs; - this->CopyMakeBorder(s, gRad, gRad, gRad,\ + // add a padding around the segments. + this->CopyMakeBorder(s, gRad, gRad, gRad, gRad, in_segs); for(size_t i = 0; i < loc.n_rows; ++i) { size_t x = loc(i, 0); size_t y = loc(i, 1); - //std::cout << "x, y = " << x << " " << y << std::endl; - lbls.row(i) = arma::vectorise(in_segs.slice(0)\ - .submat((x + gRad) - gRad, (y + gRad) - gRad,\ + // stores the segments window wise into matrix lbls. + lbls.row(i) = arma::vectorise(in_segs.slice(0) + .submat((x + gRad) - gRad, (y + gRad) - gRad, (x + gRad) + gRad - 1, (y + gRad) + gRad - 1)).t(); } } + // calculates the discrete labels from segments. arma::vec DiscreteLabels; size_t x = Discretize(lbls, 2, 256, DiscreteLabels); } @@ -852,7 +853,7 @@ IndexMin(arma::vec& k) // label to discreet classes in matrix subLbls. (this is a vector if nClass = 2) template size_t StructuredForests:: -Discretize(const MatType& labels, const size_t nClass,\ +Discretize(const MatType& labels, const size_t nClass, const size_t nSample, arma::vec& DiscreteLabels) { // Map labels to discrete class labels. @@ -867,13 +868,13 @@ Discretize(const MatType& labels, const size_t nClass,\ MatType zs(labels.n_rows, nSample); // no. of principal components to keep. 
- size_t dim = std::min( 5, std::min( (int)nSample,\ + size_t dim = std::min( 5, std::min( (int)nSample, (int)std::floor( std::log2( (int)nClass ) ) ) ); DiscreteLabels = arma::zeros(labels.n_rows, dim); + arma::uvec z1 = arma::shuffle(lis1); + arma::uvec z2 = arma::shuffle(lis1); for (size_t j = 0; j < zs.n_cols; ++j) { - arma::uvec z1 = arma::shuffle(lis1); - arma::uvec z2 = arma::shuffle(lis1); for (size_t i = 0; i < zs.n_rows; ++i) zs(i, j) = (labels(i, z1(j)) == labels(i, z2(j))) ? 1 : 0; } @@ -902,7 +903,6 @@ Discretize(const MatType& labels, const size_t nClass,\ // we take only first row in transformedData (256 * 20000) as dim = 1. Timer::Stop("pca_timer"); Timer::Start("other_discretize"); - //std::cout << Timer::Get("pca_timer") << std::endl; DiscreteLabels = arma::conv_to::from(transformedData.row(0).t() > 0); Timer::Stop("other_discretize"); } @@ -913,3 +913,4 @@ Discretize(const MatType& labels, const size_t nClass,\ } // namespace mlpack #endif + diff --git a/src/mlpack/methods/edge_boxes/feature_parameters.hpp b/src/mlpack/methods/edge_boxes/feature_parameters.hpp index 28b57ba97f9..8b0527bf675 100644 --- a/src/mlpack/methods/edge_boxes/feature_parameters.hpp +++ b/src/mlpack/methods/edge_boxes/feature_parameters.hpp @@ -5,8 +5,8 @@ * Implementation of feature parameter class. */ -#ifndef MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_IMPL_HPP -#define MLPACK_METHODS_EDGE_BOXES_STRUCTURED_TREE_IMPL_HPP +#ifndef MLPACK_METHODS_EDGE_BOXES_FEATURE_PARAMETERS_HPP +#define MLPACK_METHODS_EDGE_BOXES_FEATURE_PARAMETERS_HPP namespace mlpack { namespace structured_tree { @@ -16,8 +16,6 @@ class FeatureParameters { public: - FeatureParameters(){} //default constructor - //! getter and setter methods for all the fields in class. 
void NumImages(size_t value) { numImages = value; } size_t NumImages() const { return numImages; } @@ -123,7 +121,8 @@ class FeatureParameters size_t numTree; }; -} -} +} // namespace structured_tree +} // namespace mlpack #include "feature_extraction.hpp" #endif +