Skip to content

Commit

Permalink
Removing extra test
Browse files Browse the repository at this point in the history
  • Loading branch information
geekypathak21 committed Jun 12, 2020
1 parent 6ee4b3a commit 47b544c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 84 deletions.
4 changes: 3 additions & 1 deletion src/mlpack/methods/ann/layer/radial_basis_function_impl.hpp
Expand Up @@ -21,7 +21,9 @@ template<typename InputDataType, typename OutputDataType,
typename Activation>
RBF<InputDataType, OutputDataType, Activation>::RBF() :
inSize(0),
outSize(0)
outSize(0),
sigmas(0),
betas(0),
{
// Nothing to do here.
}
Expand Down
83 changes: 0 additions & 83 deletions src/mlpack/tests/feedforward_network_test.cpp
Expand Up @@ -706,87 +706,4 @@ BOOST_AUTO_TEST_CASE(RBFNetworkTest)
TestNetwork<>(model1, dataset, labels1, dataset, labels, 10, 0.1);
}

/**
* Train the RBF network on a larger dataset.
*/
BOOST_AUTO_TEST_CASE(RBFNetworkTest)
{
mlpack::math::RandomSeed(std::time(NULL));
// Load the dataset.
arma::mat trainData;
data::Load("thyroid_train.csv", trainData, true);

arma::mat trainLabels = trainData.row(trainData.n_rows - 1);
trainData.shed_row(trainData.n_rows - 1);

arma::mat trainLabels1 = arma::zeros(3, trainData.n_cols);
for(size_t i = 0; i < trainData.n_cols; i++)
{
trainLabels1.col(i).row((trainLabels(i) - 1)) = 1;
}

arma::mat testData;
data::Load("thyroid_test.csv", testData, true);

arma::mat testLabels = testData.row(testData.n_rows - 1);
testData.shed_row(testData.n_rows - 1);

/*
* Construct a feed forward network with trainData.n_rows input nodes,
* hiddenLayerSize hidden nodes and trainLabels.n_rows output nodes. The
* network structure looks like:
*
* Input RBF Activation Output
* Layer Layer Layer Layer
* +-----+ +-----+ +-----+ +-----+
* | | | | | | | |
* | +------>| +------>| +------>| |
* | | | | | | | |
* +-----+ +--+--+ +-----+ +-----+
*/

arma::mat centroids;
KMeans<> kmeans;
kmeans.Cluster(trainData, 8, centroids);

FFN<MeanSquaredError<> > model;
model.Add<RBF<> >(trainData.n_rows, 8, centroids);
model.Add<Linear<> >(8, 3);

// RBFN neural net with MeanSquaredError.
TestNetwork<>(model, trainData, trainLabels1, testData, testLabels, 10, 0.1);

arma::mat dataset;
dataset.load("mnist_first250_training_4s_and_9s.arm");

// Normalize each point since these are images.
for (size_t i = 0; i < dataset.n_cols; ++i)
{
dataset.col(i) /= norm(dataset.col(i), 2);
}

arma::mat labels = arma::zeros(1, dataset.n_cols);
labels.submat(0, labels.n_cols / 2, 0, labels.n_cols - 1).fill(1);

arma::mat labels1 = arma::zeros(2, dataset.n_cols);
for(size_t i = 0; i < dataset.n_cols; i++)
{
labels1.col(i).row(labels(i)) = 1;
}
labels += 1;


arma::mat centroids1;
arma::Row<size_t> assignments;
KMeans<> kmeans1;
kmeans1.Cluster(dataset, 60, centroids1);

FFN<MeanSquaredError<> > model1;
model1.Add<RBF<> >(dataset.n_rows, 60, centroids1, 0.34493);
model1.Add<Linear<> >(60, 2);

// RBFN neural net with MeanSquaredError.
TestNetwork<>(model1, dataset, labels1, dataset, labels, 10, 0.1);
}

BOOST_AUTO_TEST_SUITE_END();

0 comments on commit 47b544c

Please sign in to comment.