Skip to content

Commit

Permalink
Merge pull request #2276 from adithya-tp/improved-ann-tutorial-ffn-th…
Browse files Browse the repository at this point in the history
…yroid-example

Improved ANN Tutorial for FFN - Thyroid Dataset Example
  • Loading branch information
rcurtin committed Mar 26, 2020
2 parents 8b773e9 + 2242743 commit 148ce48
Showing 1 changed file with 50 additions and 16 deletions.
66 changes: 50 additions & 16 deletions doc/tutorials/ann/ann.txt
Expand Up @@ -204,17 +204,17 @@ using namespace mlpack::ann;


int main() int main()
{ {
// Load the training set. // Load the training set and testing set.
arma::mat dataset; arma::mat trainData;
data::Load("thyroid_train.csv", dataset, true); data::Load("thyroid_train.csv", trainData, true);

arma::mat testData;
// Split the labels from the training set. data::Load("thyroid_test.csv", testData, true);
arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
dataset.n_cols - 1); // Split the labels from the training set and testing set respectively.

arma::mat trainLabels = trainData.row(trainData.n_rows - 1);
// Split the data from the training set. arma::mat testLabels = testData.row(testData.n_rows - 1);
arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0, trainData.shed_row(trainData.n_rows - 1);
dataset.n_rows - 1, dataset.n_cols - 1); testData.shed_row(testData.n_rows - 1);


// Initialize the network. // Initialize the network.
FFN<> model; FFN<> model;
Expand All @@ -226,14 +226,48 @@ int main()
// Train the model. // Train the model.
model.Train(trainData, trainLabels); model.Train(trainData, trainLabels);


// Use the Predict method to get the assignments. // Use the Predict method to get the predictions.
arma::mat assignments; arma::mat predictionTemp;
model.Predict(trainData, assignments); model.Predict(testData, predictionTemp);

/*
Since the predictionsTemp is of dimensions (3 x number_of_data_points)
with continuous values, we first need to reduce it to a dimension of
(1 x number_of_data_points) with scalar values, to be able to compare with
testLabels.

The first step towards doing this is to create a matrix of zeros with the
desired dimensions (1 x number_of_data_points).

In predictionsTemp, the 3 dimensions for each data point correspond to the
probabilities of belonging to the three possible classes.
*/
arma::mat prediction = arma::zeros<arma::mat>(1, predictionTemp.n_cols);

// Find index of max prediction for each data point and store in "prediction"
for (size_t i = 0; i < predictionTemp.n_cols; ++i)
{
// we add 1 to the max index, so that it matches the actual test labels.
prediction(i) = arma::as_scalar(arma::find(
arma::max(predictionTemp.col(i)) == predictionTemp.col(i), 1)) + 1;
}

/*
Compute the error between predictions and testLabels,
now that we have the desired predictions.
*/
size_t correct = arma::accu(prediction == testLabels);
double classificationError = 1 - double(correct) / testData.n_cols;

// Print out the classification error for the testing dataset.
std::cout << "Classification Error for the Test set: " << classificationError << std::endl;
return 0;
} }
@endcode @endcode


Now, the matrix assignments holds the classification of each point in the Now, the matrix prediction holds the classification of each point in the
dataset. dataset. Subsequently, we find the classification error by comparing it
with testLabels.


In the next example, we create simple noisy sine sequences, which are trained In the next example, we create simple noisy sine sequences, which are trained
later on, using the RNN class in the `RNNModel()` method. later on, using the RNN class in the `RNNModel()` method.
Expand Down

0 comments on commit 148ce48

Please sign in to comment.