Skip to content
Permalink
Browse files

Merge pull request #2150 from Hemal-Mamtora/nn_example_update

Doc: fixed variable name issue & added header files.
  • Loading branch information
zoq committed Jan 15, 2020
2 parents 7bc80bb + 557581e commit 4051e29dde76f4ed6d5132ea21df82808d9dabe0
Showing with 31 additions and 21 deletions.
  1. +31 −21 doc/tutorials/ann/ann.txt
@@ -195,31 +195,41 @@ number of features in the thyroid dataset and are just used as an abstract
representation.

@code
// Load the training set.
arma::mat dataset;
data::Load("thyroid_train.csv", dataset, true);
#include <mlpack/core.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/ffn.hpp>

// Split the labels from the training set.
arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
dataset.n_cols - 1);
using namespace mlpack;
using namespace mlpack::ann;

// Split the data from the training set.
arma::mat trainLabelsTemp = dataset.submat(dataset.n_rows - 3, 0,
dataset.n_rows - 1, dataset.n_cols - 1);

// Initialize the network.
FFN<> model;
model.Add<Linear<> >(trainData.n_rows, 8);
model.Add<SigmoidLayer<> >();
model.Add<Linear<> >(8, 3);
model.Add<LogSoftMax<> >();
int main()
{
// Load the training set.
arma::mat dataset;
data::Load("thyroid_train.csv", dataset, true);

// Split the labels from the training set.
arma::mat trainData = dataset.submat(0, 0, dataset.n_rows - 4,
dataset.n_cols - 1);

// Split the data from the training set.
arma::mat trainLabels = dataset.submat(dataset.n_rows - 3, 0,
dataset.n_rows - 1, dataset.n_cols - 1);

// Initialize the network.
FFN<> model;
model.Add<Linear<> >(trainData.n_rows, 8);
model.Add<SigmoidLayer<> >();
model.Add<Linear<> >(8, 3);
model.Add<LogSoftMax<> >();

// Train the model.
model.Train(trainData, trainLabels);
// Train the model.
model.Train(trainData, trainLabels);

// Use the Predict method to get the assignments.
arma::mat assignments;
model.Predict(trainData, assignments);
// Use the Predict method to get the assignments.
arma::mat assignments;
model.Predict(trainData, assignments);
}
@endcode

Now, the matrix assignments holds the classification of each point in the

0 comments on commit 4051e29

Please sign in to comment.
You can’t perform that action at this time.