Skip to content

Commit

Permalink
avoid forest pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
mnwright committed May 9, 2018
1 parent b90eed6 commit 4d8d818
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 72 deletions.
109 changes: 53 additions & 56 deletions cpp_version/src/main.cpp
@@ -1,18 +1,19 @@
/*------------------------------------------------------------------------------- /*-------------------------------------------------------------------------------
This file is part of ranger. This file is part of ranger.
Copyright (c) [2014-2018] [Marvin N. Wright] Copyright (c) [2014-2018] [Marvin N. Wright]
This software may be modified and distributed under the terms of the MIT license. This software may be modified and distributed under the terms of the MIT license.
Please note that the C++ core of ranger is distributed under MIT license and the Please note that the C++ core of ranger is distributed under MIT license and the
R package "ranger" under GPL3 license. R package "ranger" under GPL3 license.
#-------------------------------------------------------------------------------*/ #-------------------------------------------------------------------------------*/


#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <memory>


#include "globals.h" #include "globals.h"
#include "ArgumentHandler.h" #include "ArgumentHandler.h"
Expand All @@ -23,72 +24,68 @@ R package "ranger" under GPL3 license.


using namespace ranger; using namespace ranger;


void run_ranger(const ArgumentHandler& arg_handler, std::ostream& verbose_out) {
verbose_out << "Starting Ranger." << std::endl;

// Create forest object
std::unique_ptr<Forest> forest {};
switch (arg_handler.treetype) {
case TREE_CLASSIFICATION:
if (arg_handler.probability) {
forest = std::unique_ptr<Forest>(new ForestProbability);
} else {
forest = std::unique_ptr<Forest>(new ForestClassification);
}
break;
case TREE_REGRESSION:
forest = std::unique_ptr<Forest>(new ForestRegression);
break;
case TREE_SURVIVAL:
forest = std::unique_ptr<Forest>(new ForestSurvival);
break;
case TREE_PROBABILITY:
forest = std::unique_ptr<Forest>(new ForestProbability);
break;
}

// Call Ranger
forest->initCpp(arg_handler.depvarname, arg_handler.memmode, arg_handler.file, arg_handler.mtry,
arg_handler.outprefix, arg_handler.ntree, &verbose_out, arg_handler.seed, arg_handler.nthreads,
arg_handler.predict, arg_handler.impmeasure, arg_handler.targetpartitionsize, arg_handler.splitweights,
arg_handler.alwayssplitvars, arg_handler.statusvarname, arg_handler.replace, arg_handler.catvars,
arg_handler.savemem, arg_handler.splitrule, arg_handler.caseweights, arg_handler.predall, arg_handler.fraction,
arg_handler.alpha, arg_handler.minprop, arg_handler.holdout, arg_handler.predictiontype,
arg_handler.randomsplits);

forest->run(true);
if (arg_handler.write) {
forest->saveToFile();
}
forest->writeOutput();
verbose_out << "Finished Ranger." << std::endl;
}

int main(int argc, char **argv) { int main(int argc, char **argv) {


ArgumentHandler arg_handler(argc, argv);
Forest* forest = 0;
try { try {

// Handle command line arguments // Handle command line arguments
ArgumentHandler arg_handler(argc, argv);
if (arg_handler.processArguments() != 0) { if (arg_handler.processArguments() != 0) {
return 0; return 0;
} }
arg_handler.checkArguments(); arg_handler.checkArguments();


// Create forest object
switch (arg_handler.treetype) {
case TREE_CLASSIFICATION:
if (arg_handler.probability) {
forest = new ForestProbability;
} else {
forest = new ForestClassification;
}
break;
case TREE_REGRESSION:
forest = new ForestRegression;
break;
case TREE_SURVIVAL:
forest = new ForestSurvival;
break;
case TREE_PROBABILITY:
forest = new ForestProbability;
break;
}

// Verbose output to logfile if non-verbose mode
std::ostream* verbose_out;
if (arg_handler.verbose) { if (arg_handler.verbose) {
verbose_out = &std::cout; run_ranger(arg_handler, std::cout);
} else { } else {
std::ofstream* logfile = new std::ofstream(); std::ofstream logfile { arg_handler.outprefix + ".log" };
logfile->open(arg_handler.outprefix + ".log"); if (!logfile.good()) {
if (!logfile->good()) {
throw std::runtime_error("Could not write to logfile."); throw std::runtime_error("Could not write to logfile.");
} }
verbose_out = logfile; run_ranger(arg_handler, logfile);
}

// Call Ranger
*verbose_out << "Starting Ranger." << std::endl;
forest->initCpp(arg_handler.depvarname, arg_handler.memmode, arg_handler.file, arg_handler.mtry,
arg_handler.outprefix, arg_handler.ntree, verbose_out, arg_handler.seed, arg_handler.nthreads,
arg_handler.predict, arg_handler.impmeasure, arg_handler.targetpartitionsize, arg_handler.splitweights,
arg_handler.alwayssplitvars, arg_handler.statusvarname, arg_handler.replace, arg_handler.catvars,
arg_handler.savemem, arg_handler.splitrule, arg_handler.caseweights, arg_handler.predall, arg_handler.fraction,
arg_handler.alpha, arg_handler.minprop, arg_handler.holdout, arg_handler.predictiontype,
arg_handler.randomsplits);

forest->run(true);
if (arg_handler.write) {
forest->saveToFile();
} }
forest->writeOutput();
*verbose_out << "Finished Ranger." << std::endl;

delete forest;
} catch (std::exception& e) { } catch (std::exception& e) {
std::cerr << "Error: " << e.what() << " Ranger will EXIT now." << std::endl; std::cerr << "Error: " << e.what() << " Ranger will EXIT now." << std::endl;
delete forest;
return -1; return -1;
} }


Expand Down
2 changes: 1 addition & 1 deletion src/Data.h
Expand Up @@ -148,7 +148,7 @@ class Data {
return is_ordered_variable; return is_ordered_variable;
} }


void setIsOrderedVariable(std::vector<std::string>& unordered_variable_names) { void setIsOrderedVariable(const std::vector<std::string>& unordered_variable_names) {
is_ordered_variable.resize(num_cols, true); is_ordered_variable.resize(num_cols, true);
for (auto& variable_name : unordered_variable_names) { for (auto& variable_name : unordered_variable_names) {
size_t varID = getVariableID(variable_name); size_t varID = getVariableID(variable_name);
Expand Down
18 changes: 9 additions & 9 deletions src/Forest.cpp
Expand Up @@ -47,11 +47,11 @@ Forest::~Forest() {
void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode, std::string input_file, uint mtry, void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode, std::string input_file, uint mtry,
std::string output_prefix, uint num_trees, std::ostream* verbose_out, uint seed, uint num_threads, std::string output_prefix, uint num_trees, std::ostream* verbose_out, uint seed, uint num_threads,
std::string load_forest_filename, ImportanceMode importance_mode, uint min_node_size, std::string load_forest_filename, ImportanceMode importance_mode, uint min_node_size,
std::string split_select_weights_file, std::vector<std::string>& always_split_variable_names, std::string split_select_weights_file, const std::vector<std::string>& always_split_variable_names,
std::string status_variable_name, bool sample_with_replacement, std::vector<std::string>& unordered_variable_names, std::string status_variable_name, bool sample_with_replacement,
bool memory_saving_splitting, SplitRule splitrule, std::string case_weights_file, bool predict_all, const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
double sample_fraction, double alpha, double minprop, bool holdout, PredictionType prediction_type, std::string case_weights_file, bool predict_all, double sample_fraction, double alpha, double minprop, bool holdout,
uint num_random_splits) { PredictionType prediction_type, uint num_random_splits) {


this->verbose_out = verbose_out; this->verbose_out = verbose_out;


Expand Down Expand Up @@ -143,9 +143,9 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode


void Forest::initR(std::string dependent_variable_name, Data* input_data, uint mtry, uint num_trees, void Forest::initR(std::string dependent_variable_name, Data* input_data, uint mtry, uint num_trees,
std::ostream* verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, std::ostream* verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size,
std::vector<std::vector<double>>& split_select_weights, std::vector<std::string>& always_split_variable_names, std::vector<std::vector<double>>& split_select_weights, const std::vector<std::string>& always_split_variable_names,
std::string status_variable_name, bool prediction_mode, bool sample_with_replacement, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement,
std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
std::vector<double>& case_weights, bool predict_all, bool keep_inbag, std::vector<double>& sample_fraction, std::vector<double>& case_weights, bool predict_all, bool keep_inbag, std::vector<double>& sample_fraction,
double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits, double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits,
bool order_snps) { bool order_snps) {
Expand Down Expand Up @@ -183,7 +183,7 @@ void Forest::initR(std::string dependent_variable_name, Data* input_data, uint m
void Forest::init(std::string dependent_variable_name, MemoryMode memory_mode, Data* input_data, uint mtry, void Forest::init(std::string dependent_variable_name, MemoryMode memory_mode, Data* input_data, uint mtry,
std::string output_prefix, uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, std::string output_prefix, uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode,
uint min_node_size, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement, uint min_node_size, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement,
std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
bool predict_all, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout, bool predict_all, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout,
PredictionType prediction_type, uint num_random_splits, bool order_snps) { PredictionType prediction_type, uint num_random_splits, bool order_snps) {


Expand Down Expand Up @@ -874,7 +874,7 @@ void Forest::setSplitWeightVector(std::vector<std::vector<double>>& split_select
} }
} }


void Forest::setAlwaysSplitVariables(std::vector<std::string>& always_split_variable_names) { void Forest::setAlwaysSplitVariables(const std::vector<std::string>& always_split_variable_names) {


deterministic_varIDs.reserve(num_independent_variables); deterministic_varIDs.reserve(num_independent_variables);


Expand Down
12 changes: 6 additions & 6 deletions src/Forest.h
Expand Up @@ -41,23 +41,23 @@ class Forest {
void initCpp(std::string dependent_variable_name, MemoryMode memory_mode, std::string input_file, uint mtry, void initCpp(std::string dependent_variable_name, MemoryMode memory_mode, std::string input_file, uint mtry,
std::string output_prefix, uint num_trees, std::ostream* verbose_out, uint seed, uint num_threads, std::string output_prefix, uint num_trees, std::ostream* verbose_out, uint seed, uint num_threads,
std::string load_forest_filename, ImportanceMode importance_mode, uint min_node_size, std::string load_forest_filename, ImportanceMode importance_mode, uint min_node_size,
std::string split_select_weights_file, std::vector<std::string>& always_split_variable_names, std::string split_select_weights_file, const std::vector<std::string>& always_split_variable_names,
std::string status_variable_name, bool sample_with_replacement, std::string status_variable_name, bool sample_with_replacement,
std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
std::string case_weights_file, bool predict_all, double sample_fraction, double alpha, double minprop, std::string case_weights_file, bool predict_all, double sample_fraction, double alpha, double minprop,
bool holdout, PredictionType prediction_type, uint num_random_splits); bool holdout, PredictionType prediction_type, uint num_random_splits);
void initR(std::string dependent_variable_name, Data* input_data, uint mtry, uint num_trees, void initR(std::string dependent_variable_name, Data* input_data, uint mtry, uint num_trees,
std::ostream* verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, std::ostream* verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size,
std::vector<std::vector<double>>& split_select_weights, std::vector<std::string>& always_split_variable_names, std::vector<std::vector<double>>& split_select_weights, const std::vector<std::string>& always_split_variable_names,
std::string status_variable_name, bool prediction_mode, bool sample_with_replacement, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement,
std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
std::vector<double>& case_weights, bool predict_all, bool keep_inbag, std::vector<double>& sample_fraction, std::vector<double>& case_weights, bool predict_all, bool keep_inbag, std::vector<double>& sample_fraction,
double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits, double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits,
bool order_snps); bool order_snps);
void init(std::string dependent_variable_name, MemoryMode memory_mode, Data* input_data, uint mtry, void init(std::string dependent_variable_name, MemoryMode memory_mode, Data* input_data, uint mtry,
std::string output_prefix, uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, std::string output_prefix, uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode,
uint min_node_size, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement, uint min_node_size, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement,
std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
bool predict_all, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout, bool predict_all, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout,
PredictionType prediction_type, uint num_random_splits, bool order_snps); PredictionType prediction_type, uint num_random_splits, bool order_snps);
virtual void initInternal(std::string status_variable_name) = 0; virtual void initInternal(std::string status_variable_name) = 0;
Expand Down Expand Up @@ -165,7 +165,7 @@ class Forest {


// Set split select weights and variables to be always considered for splitting // Set split select weights and variables to be always considered for splitting
void setSplitWeightVector(std::vector<std::vector<double>>& split_select_weights); void setSplitWeightVector(std::vector<std::vector<double>>& split_select_weights);
void setAlwaysSplitVariables(std::vector<std::string>& always_split_variable_names); void setAlwaysSplitVariables(const std::vector<std::string>& always_split_variable_names);


// Show progress every few seconds // Show progress every few seconds
#ifdef OLD_WIN_R_BUILD #ifdef OLD_WIN_R_BUILD
Expand Down

0 comments on commit 4d8d818

Please sign in to comment.