Skip to content
Browse files

Merge pull request #2276 from adithya-tp/improved-ann-tutorial-ffn-th…


Improved ANN Tutorial for FFN - Thyroid Dataset Example
  • Loading branch information
rcurtin committed Mar 26, 2020
2 parents 8b773e9 + 2242743 commit 148ce4891f174223e6ab5efaae569679206b9b01
Showing with 50 additions and 16 deletions.
  1. +50 −16 doc/tutorials/ann/ann.txt
@@ -204,17 +204,17 @@ using namespace mlpack::ann;

int main()
// Load the training set.
arma::mat dataset;
data::Load("thyroid_train.csv", dataset, true);

// Split the labels from the training set.
arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
dataset.n_cols - 1);

// Split the data from the training set.
arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0,
dataset.n_rows - 1, dataset.n_cols - 1);
// Load the training set and testing set.
arma::mat trainData;
data::Load("thyroid_train.csv", trainData, true);
arma::mat testData;
data::Load("thyroid_test.csv", testData, true);

// Split the labels from the training set and testing set respectively.
arma::mat trainLabels = trainData.row(trainData.n_rows - 1);
arma::mat testLabels = testData.row(testData.n_rows - 1);
trainData.shed_row(trainData.n_rows - 1);
testData.shed_row(testData.n_rows - 1);

// Initialize the network.
FFN<> model;
@@ -226,14 +226,48 @@ int main()
// Train the model.
model.Train(trainData, trainLabels);

// Use the Predict method to get the assignments.
arma::mat assignments;
model.Predict(trainData, assignments);
// Use the Predict method to get the predictions.
arma::mat predictionTemp;
model.Predict(testData, predictionTemp);

Since the predictionsTemp is of dimensions (3 x number_of_data_points)
with continuous values, we first need to reduce it to a dimension of
(1 x number_of_data_points) with scalar values, to be able to compare with

The first step towards doing this is to create a matrix of zeros with the
desired dimensions (1 x number_of_data_points).

In predictionsTemp, the 3 dimensions for each data point correspond to the
probabilities of belonging to the three possible classes.
arma::mat prediction = arma::zeros<arma::mat>(1, predictionTemp.n_cols);

// Find index of max prediction for each data point and store in "prediction"
for (size_t i = 0; i < predictionTemp.n_cols; ++i)
// we add 1 to the max index, so that it matches the actual test labels.
prediction(i) = arma::as_scalar(arma::find(
arma::max(predictionTemp.col(i)) == predictionTemp.col(i), 1)) + 1;

Compute the error between predictions and testLabels,
now that we have the desired predictions.
size_t correct = arma::accu(prediction == testLabels);
double classificationError = 1 - double(correct) / testData.n_cols;

// Print out the classification error for the testing dataset.
std::cout << "Classification Error for the Test set: " << classificationError << std::endl;
return 0;

Now, the matrix assignments holds the classification of each point in the
Now, the matrix prediction holds the classification of each point in the
dataset. Subsequently, we find the classification error by comparing it
with testLabels.

In the next example, we create simple noisy sine sequences, which are trained
later on, using the RNN class in the `RNNModel()` method.

0 comments on commit 148ce48

Please sign in to comment.
You can’t perform that action at this time.