Skip to content
Permalink
Browse files

exitWithHelp() for ProtoNN

  • Loading branch information...
pushkalkatara committed Jul 23, 2019
1 parent ce96e90 commit 5144010ec138013835fa740281254d79f51227a2
Showing with 41 additions and 11 deletions.
  1. +7 −6 cpp/src/ProtoNN/ProtoNN.h
  2. +34 −5 cpp/src/ProtoNN/ProtoNNHyperParams.cpp
@@ -41,7 +41,7 @@ namespace EdgeML

//
// ProtoNNModel includes hyperparameters, parameters, and some state on initialization
//
//
class ProtoNNModel
{

@@ -71,6 +71,7 @@ namespace EdgeML

void setHyperParamsFromArgs(const int argc, const char** argv);
void finalizeHyperParams();
void exitWithHelp();

//
// Create a string with hyperParam settings
@@ -115,7 +116,7 @@ namespace EdgeML
class ProtoNNTrainer
{
////////////////////////////////////////////////////////
// DO NOT REORDER model and data.
// DO NOT REORDER model and data.
// They should be in this order for constructors to work
ProtoNNModel model;
Data data;
@@ -236,7 +237,7 @@ namespace EdgeML
void RBF();

void setFromArgs(const int argc, const char** argv);

void createOutputDirs();

public:
@@ -253,7 +254,7 @@ namespace EdgeML

~ProtoNNPredictor();

// Not thread safe
// Not thread safe
FP_TYPE testDenseDataPoint(
const FP_TYPE *const values,
const labelCount_t *const labels,
@@ -277,13 +278,13 @@ namespace EdgeML
dataCount_t batchSize);

ResultStruct testBatchWise();

ResultStruct testPointWise();

ResultStruct test();

void saveTopKScores(std::string filename="", int topk=5);

void normalize();
};
}
@@ -19,7 +19,7 @@ ProtoNNModel::ProtoNNHyperParams::ProtoNNHyperParams()

ntrain = 0;
nvalidation = 0;

iters = 20;
epochs = 20;
batchSize = 1024;
@@ -87,7 +87,7 @@ void ProtoNNModel::ProtoNNHyperParams::finalizeHyperParams()
assert(epochs >= 1 && "number of epochs should be >= 1");
// Following asserts removed to faciliate support for TLC
// which does not know how many datapoints are going to be fed before-hand!
// assert(ntrain >= 1);
// assert(ntrain >= 1);
// assert(nvalidation >= 0);
// assert(m <= ntrain);
if (d > D) {
@@ -112,7 +112,7 @@ void ProtoNNModel::ProtoNNHyperParams::setHyperParamsFromArgs(const int argc, co
{
for (int i = 1; i < argc; ++i) {
if (i % 2 == 1)
assert(argv[i][0] == '-'); //odd arguments must be specifiers, not values
assert(argv[i][0] == '-'); //odd arguments must be specifiers, not values
else {

switch (argv[i - 1][1]) {
@@ -143,7 +143,7 @@ void ProtoNNModel::ProtoNNHyperParams::setHyperParamsFromArgs(const int argc, co
if (argv[i][0] == '0') problemType = binary;
else if (argv[i][0] == '1') problemType = multiclass;
else if (argv[i][0] == '2') problemType = multilabel;
else assert(false); //Problem type unknown
else exitWithHelp();
break;
case 'W':
lambdaW = (FP_TYPE)strtod(argv[i], NULL);
@@ -193,11 +193,40 @@ void ProtoNNModel::ProtoNNHyperParams::setHyperParamsFromArgs(const int argc, co

default:
LOG_INFO("Command line argument not recognized; saw character: " + std::string(1, argv[i - 1][1]));
assert(false);
exitWithHelp();
break;
}
}
}

finalizeHyperParams();
}


void ProtoNNModel::ProtoNNHyperParams::exitWithHelp()
{
LOG_INFO("Options:");

LOG_INFO("-P : [Required] Option to load a predefined model, Visit docs for format. [Default: 0]");
LOG_INFO("-R : [Required] A random number seed which can be used to re-generate previously obtained experimental results. [Default: 42]");
LOG_INFO("-r : [Required] Number of training points.");
LOG_INFO("-v : [Required] Number of validation/test points.");
LOG_INFO("-D : [Required] The original dimension of the data.");
LOG_INFO("-l : [Required] Number of Classes");
LOG_INFO("-C : [Required] Problem Format. Specify one from 0 (binary), 1 (multiclass), 2 (multilabel)");
LOG_INFO("-d : [Required] Projection dimension (the dimension into which the data is projected). [Default: 15]");
LOG_INFO("-m : [m or k Required] Number of Prototypes. [Default: 20]");
LOG_INFO("-k : [m or k Required] Number of Prototypes Per Class.\n");

LOG_INFO("-g : [Optional] GammaNumerator, also alters RBF kernel parameter 𝛾 =(2.5⋅𝐺𝑎𝑚𝑚𝑎𝑁𝑢𝑚𝑒𝑟𝑎𝑡𝑜𝑟)/(𝑚𝑒𝑑𝑖𝑎𝑛(||𝐵𝑗,𝑊−𝑋𝑖||22)). [Default: 1.0] ");
LOG_INFO("-W : [Optional] Projection sparsity ( 𝜆𝑊 ). [Default: 1.0] ");
LOG_INFO("-Z : [Optional] Label Sparsity. [Default: 1.0]");
LOG_INFO("-B : [Optional] Prototype sparsity. [Default: 1.0]\n");


LOG_INFO("-T : [Optional] Total number of optimization iterations. [Default: 20]");
LOG_INFO("-E : [Optional] Number of epochs (complete see-through's) of the data for each iteration, and each parameter. [Default: 20]");
LOG_INFO("-N : [Optional] Normalization. Default: 0 (No Normalization), 1 (Min-Max Normalization), 2 (L2-Normalization)\n");

exit(1);
}

0 comments on commit 5144010

Please sign in to comment.
You can’t perform that action at this time.