From c6a8f87b47b372d6ef133ba17e79f72ecb509910 Mon Sep 17 00:00:00 2001 From: himanshupathak21061998 Date: Thu, 14 May 2020 22:00:06 +0530 Subject: [PATCH] Adding changes in implementation --- src/mlpack/methods/ann/layer/layer_types.hpp | 1 - .../ann/layer/radial_basis_function_impl.hpp | 22 +++-- src/mlpack/tests/feedforward_network_test.cpp | 90 +++++++++---------- 3 files changed, 60 insertions(+), 53 deletions(-) diff --git a/src/mlpack/methods/ann/layer/layer_types.hpp b/src/mlpack/methods/ann/layer/layer_types.hpp index abe1e461c95..18c61c64ae3 100644 --- a/src/mlpack/methods/ann/layer/layer_types.hpp +++ b/src/mlpack/methods/ann/layer/layer_types.hpp @@ -252,7 +252,6 @@ using LayerTypes = boost::variant< LeakyReLU*, CReLU*, Linear*, - RBF*, LinearNoBias*, LogSoftMax*, Lookup*, diff --git a/src/mlpack/methods/ann/layer/radial_basis_function_impl.hpp b/src/mlpack/methods/ann/layer/radial_basis_function_impl.hpp index 71b6d7e3f40..85f218d0d68 100644 --- a/src/mlpack/methods/ann/layer/radial_basis_function_impl.hpp +++ b/src/mlpack/methods/ann/layer/radial_basis_function_impl.hpp @@ -41,19 +41,27 @@ void RBF::Forward( const arma::Mat& input, arma::Mat& output) { - centres = arma::mat(outSize, input.n_rows, arma::fill::randu); + centres = arma::mat(input.n_rows, outSize, arma::fill::randu); centres = arma::normcdf(centres, 0, 1); - sigmas = arma::ones(1, outSize); + sigmas = arma::ones(outSize, 1); + arma::cube x = arma::cube(input.n_rows, outSize, input.n_cols); - distances = arma::mat(outSize, input.n_cols); + for (size_t i = 0; i < input.n_cols; i++) + { + x.slice(i).each_col() = input.col(i); + } + + arma::cube c = arma::cube(input.n_rows, outSize, input.n_cols); + c.each_slice() = centres; + + distances = arma::mat(input.n_rows, input.n_cols); - for (size_t i = 0; i < outSize; i++) + for (size_t i = 0; i < input.n_cols; i++) { - arma::mat temp = centres.each_col(i) distances.col(i) = arma::pow(arma::sum( arma::pow(( - temp), - 2), 1), 0.5).t() * sigmas(i); + x.slice(i) - c.slice(i)), + 2), 1), 0.5); } output = distances; diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp index 4f2b3068b51..bd7dbc7a51c 100644 --- a/src/mlpack/tests/feedforward_network_test.cpp +++ b/src/mlpack/tests/feedforward_network_test.cpp @@ -624,20 +624,20 @@ BOOST_AUTO_TEST_CASE(OptimizerTest) /** * Train the RBF network on a larger dataset. */ -//BOOST_AUTO_TEST_CASE(RBFNetworkTest) -//{ -// // Load the dataset. -// arma::mat trainData; -// data::Load("thyroid_train.csv", trainData, true); -// -// arma::mat trainLabels = trainData.row(trainData.n_rows - 1); -// trainData.shed_row(trainData.n_rows - 1); -// -// arma::mat testData; -// data::Load("thyroid_test.csv", testData, true); -// -// arma::mat testLabels = testData.row(testData.n_rows - 1); -// testData.shed_row(testData.n_rows - 1); +BOOST_AUTO_TEST_CASE(RBFNetworkTest) +{ + // Load the dataset. + arma::mat trainData; + data::Load("thyroid_train.csv", trainData, true); + + arma::mat trainLabels = trainData.row(trainData.n_rows - 1); + trainData.shed_row(trainData.n_rows - 1); + + arma::mat testData; + data::Load("thyroid_test.csv", testData, true); + + arma::mat testLabels = testData.row(testData.n_rows - 1); + testData.shed_row(testData.n_rows - 1); /* * Construct a feed forward network with trainData.n_rows input nodes, @@ -652,36 +652,36 @@ BOOST_AUTO_TEST_CASE(OptimizerTest) * | | | | | | | | * +-----+ +--+--+ +-----+ +-----+ */ -// -// FFN > model; -// model.Add >(trainData.n_cols, 8); -// model.Add >(); -// model.Add >(trainData.n_rows, 8); -// model.Add >(8, 3); -// model.Add >(); -// -// TestNetwork<>(model, trainData, trainLabels, testData, testLabels, 10, 0.1); -// arma::mat dataset; -// dataset.load("mnist_first250_training_4s_and_9s.arm"); -// -// // Normalize each point since these are images. -// for (size_t i = 0; i < dataset.n_cols; ++i) -// { -// dataset.col(i) /= norm(dataset.col(i), 2); -// } -// -// arma::mat labels = arma::zeros(1, dataset.n_cols); -// labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1); -// labels += 1; -// -// FFN > model1; -// model1.Add >(dataset.n_cols, 10); -// model1.Add >(); -// model1.Add >(dataset.n_rows, 10); -// model.Add >(10, 2); -// model1.Add >(); -// // Vanilla neural net with logistic activation function. -// TestNetwork<>(model1, dataset, labels, dataset, labels, 10, 0.2); -//} + + FFN > model; + model.Add >(trainData.n_cols, 8); + model.Add >(); + model.Add >(trainData.n_rows, 8); + model.Add >(8, 3); + model.Add >(); + + TestNetwork<>(model, trainData, trainLabels, testData, testLabels, 10, 0.1); + arma::mat dataset; + dataset.load("mnist_first250_training_4s_and_9s.arm"); + + // Normalize each point since these are images. + for (size_t i = 0; i < dataset.n_cols; ++i) + { + dataset.col(i) /= norm(dataset.col(i), 2); + } + + arma::mat labels = arma::zeros(1, dataset.n_cols); + labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1); + labels += 1; + + FFN > model1; + model1.Add >(dataset.n_cols, 10); + model1.Add >(); + model1.Add >(dataset.n_rows, 10); + model.Add >(10, 2); + model1.Add >(); + // Vanilla neural net with logistic activation function. + TestNetwork<>(model1, dataset, labels, dataset, labels, 10, 0.2); +} BOOST_AUTO_TEST_SUITE_END();