Skip to content

Commit

Permalink
add force_split functionality (#1310)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu authored and guolinke committed Apr 24, 2018
1 parent 71539cc commit 84fef71
Show file tree
Hide file tree
Showing 20 changed files with 1,428 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Expand Up @@ -144,7 +144,7 @@ if(USE_MPI)
include_directories(${MPI_CXX_INCLUDE_PATH})
endif(USE_MPI)

file(GLOB SOURCES
file(GLOB SOURCES
src/application/*.cpp
src/boosting/*.cpp
src/io/*.cpp
Expand Down
10 changes: 10 additions & 0 deletions docs/Parameters.rst
Expand Up @@ -520,6 +520,16 @@ IO Parameters

- separate by ``,`` for multi-validation data

- ``forced_splits``, default=\ ``""``, type=string

- path to a ``.json`` file that specifies splits to force at the top of every decision tree before best-first learning commences.

- the ``.json`` file can be arbitrarily nested; each split contains ``feature`` and ``threshold`` fields, and may contain ``left`` and ``right``
fields representing sub-splits. Categorical splits are forced in a one-hot fashion, with ``left`` representing the split containing
the feature value and ``right`` representing all other values.

- see ``examples/binary_classification/forced_splits.json`` for an example.

Objective Parameters
--------------------

Expand Down
12 changes: 12 additions & 0 deletions examples/binary_classification/forced_splits.json
@@ -0,0 +1,12 @@
{
"feature": 25,
"threshold": 1.30,
"left": {
"feature": 26,
"threshold": 0.85
},
"right": {
"feature": 26,
"threshold": 0.85
}
}
3 changes: 3 additions & 0 deletions examples/binary_classification/train.conf
Expand Up @@ -109,3 +109,6 @@ local_listen_port = 12400

# machines list file for parallel training, alias: mlist
machine_list_file = mlist.txt

# forced splits: path to a .json file of splits to force at the top of every tree
# forced_splits = forced_splits.json
7 changes: 6 additions & 1 deletion include/LightGBM/config.h
Expand Up @@ -105,6 +105,7 @@ struct IOConfig: public ConfigBase {
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "gbdt_prediction.cpp";
std::string input_model = "";

int verbosity = 1;
int num_iteration_predict = -1;
bool is_pre_partition = false;
Expand Down Expand Up @@ -264,6 +265,9 @@ struct BoostingConfig: public ConfigBase {
std::string device_type = kDefaultDevice;
TreeConfig tree_config;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;

/*! \brief Path to the JSON file describing forced splits; an empty string means no file is loaded */
std::string forcedsplits_filename = "";
};

/*! \brief Config for Network */
Expand Down Expand Up @@ -482,7 +486,8 @@ struct ParameterAlias {
"histogram_pool_size", "is_provide_training_metric", "machine_list_filename", "machines",
"zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib",
"max_cat_threshold", "cat_smooth", "min_data_per_group", "cat_l2", "max_cat_to_onehot",
"alpha", "reg_sqrt", "tweedie_variance_power", "monotone_constraints", "max_delta_step"
"alpha", "reg_sqrt", "tweedie_variance_power", "monotone_constraints", "max_delta_step",
"forced_splits"
});
std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) {
Expand Down
7 changes: 7 additions & 0 deletions include/LightGBM/dataset.h
Expand Up @@ -495,6 +495,13 @@ class Dataset {
return feature_groups_[group]->bin_mappers_[sub_feature]->BinToValue(threshold);
}

// Converts a real-valued threshold for feature i into a bin index by delegating to the
// ValueToBin method of that feature's bin mapper.
// NOTE(review): the previous comment called the result the "closest threshold bin"; the
// exact rounding is decided by ValueToBin, which is not shown here — confirm.
inline uint32_t BinThreshold(int i, double threshold_double) const {
  // Locate the feature group that owns feature i, then the mapper for its slot in that group.
  const auto& owning_group = feature_groups_[feature2group_[i]];
  const auto& mapper = owning_group->bin_mappers_[feature2subfeature_[i]];
  return mapper->ValueToBin(threshold_double);
}

inline void CreateOrderedBins(std::vector<std::unique_ptr<OrderedBin>>* ordered_bins) const {
ordered_bins->resize(num_groups_);
OMP_INIT_EX();
Expand Down
232 changes: 232 additions & 0 deletions include/LightGBM/json11.hpp
@@ -0,0 +1,232 @@
/* json11
*
* json11 is a tiny JSON library for C++11, providing JSON parsing and serialization.
*
* The core object provided by the library is json11::Json. A Json object represents any JSON
* value: null, bool, number (int or double), string (std::string), array (std::vector), or
* object (std::map).
*
* Json objects act like values: they can be assigned, copied, moved, compared for equality or
* order, etc. There are also helper methods Json::dump, to serialize a Json to a string, and
* Json::parse (static) to parse a std::string as a Json object.
*
* Internally, the various types of Json object are represented by the JsonValue class
* hierarchy.
*
* A note on numbers - JSON specifies the syntax of number formatting but not its semantics,
* so some JSON implementations distinguish between integers and floating-point numbers, while
* some don't. In json11, we choose the latter. Because some JSON implementations (namely
* Javascript itself) treat all numbers as the same type, distinguishing the two leads
* to JSON that will be *silently* changed by a round-trip through those implementations.
* Dangerous! To avoid that risk, json11 stores all numbers as double internally, but also
* provides integer helpers.
*
* Fortunately, double-precision IEEE754 ('double') can precisely store any integer in the
* range +/-2^53, which includes every 'int' on most systems. (Timestamps often use int64
* or long long to avoid the Y2038 problem; a double storing microseconds since some epoch
* will be exact for +/- 275 years.)
*/

/* Copyright (c) 2013 Dropbox, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#pragma once

#include <string>
#include <vector>
#include <map>
#include <memory>
#include <initializer_list>

#ifdef _MSC_VER
#if _MSC_VER <= 1800 // VS 2013
#ifndef noexcept
#define noexcept throw()
#endif

#ifndef snprintf
#define snprintf _snprintf_s
#endif
#endif
#endif

namespace json11 {

// Parsing strategy passed to Json::parse and Json::parse_multi.
// NOTE(review): presumably COMMENTS lets the parser accept comments in the input while
// STANDARD does not — confirm against the parser implementation.
enum JsonParse {
  STANDARD,
  COMMENTS
};

class JsonValue;

// Value-like handle to one JSON value of any kind. The payload lives in a JsonValue
// object that this handle holds through a shared pointer.
class Json final {
 public:
  // Kind of JSON value stored in a Json object.
  enum Type {
    NUL,
    NUMBER,
    BOOL,
    STRING,
    ARRAY,
    OBJECT
  };

  // Containers backing the ARRAY and OBJECT kinds.
  using array = std::vector<Json>;
  using object = std::map<std::string, Json>;

  // One constructor per kind of JSON value; the trailing note names the resulting Type.
  Json() noexcept;                 // NUL
  Json(std::nullptr_t) noexcept;   // NUL
  Json(double value);              // NUMBER
  Json(int value);                 // NUMBER
  Json(bool value);                // BOOL
  Json(const std::string& value);  // STRING
  Json(std::string&& value);       // STRING
  Json(const char* value);         // STRING
  Json(const array& values);       // ARRAY
  Json(array&& values);            // ARRAY
  Json(const object& values);      // OBJECT
  Json(object&& values);           // OBJECT

  // Implicit conversion from any type that provides a to_json() member.
  template <typename T, typename = decltype(&T::to_json)>
  Json(const T& t) : Json(t.to_json()) {}

  // Implicit conversion from map-like containers (std::map, std::unordered_map, ...):
  // enabled when the key can build a std::string and the mapped value can build a Json.
  template <typename M, typename std::enable_if<
      std::is_constructible<std::string, decltype(std::declval<M>().begin()->first)>::value
          && std::is_constructible<Json, decltype(std::declval<M>().begin()->second)>::value,
      int>::type = 0>
  Json(const M& m) : Json(object(m.begin(), m.end())) {}

  // Implicit conversion from sequence-like containers (std::list, std::vector, std::set, ...):
  // enabled when an element can build a Json.
  template <typename V, typename std::enable_if<
      std::is_constructible<Json, decltype(*std::declval<V>().begin())>::value,
      int>::type = 0>
  Json(const V& v) : Json(array(v.begin(), v.end())) {}

  // Deleted so that Json(some_pointer) cannot silently turn into a BOOL.
  // Write Json(bool(some_pointer)) when that conversion is really wanted.
  Json(void*) = delete;

  // Kind of value currently held.
  Type type() const;

  // Convenience predicates, one per Type.
  bool is_null() const { return NUL == type(); }
  bool is_number() const { return NUMBER == type(); }
  bool is_bool() const { return BOOL == type(); }
  bool is_string() const { return STRING == type(); }
  bool is_array() const { return ARRAY == type(); }
  bool is_object() const { return OBJECT == type(); }

  // Numeric payload when this is a NUMBER, 0 otherwise. json11 keeps no distinction between
  // integral and fractional numbers, so both accessors are valid on any NUMBER.
  double number_value() const;
  int int_value() const;

  // Boolean payload when this is a BOOL, false otherwise.
  bool bool_value() const;
  // String payload when this is a STRING, "" otherwise.
  const std::string& string_value() const;
  // Element vector when this is an ARRAY, an empty vector otherwise.
  const array& array_items() const;
  // Key/value map when this is an OBJECT, an empty map otherwise.
  const object& object_items() const;

  // arr[i] when this is an ARRAY, Json() otherwise.
  const Json& operator[](size_t i) const;
  // obj[key] when this is an OBJECT, Json() otherwise.
  const Json& operator[](const std::string& key) const;

  // Serialization: the first overload writes into a caller-supplied string,
  // the second returns a freshly built string.
  void dump(std::string& out) const;
  std::string dump() const {
    std::string serialized;
    dump(serialized);
    return serialized;
  }

  // Parsing. On failure the result is Json() and err receives an error message.
  static Json parse(const std::string& in,
                    std::string& err,
                    JsonParse strategy = JsonParse::STANDARD);
  // C-string overload: a null pointer is reported through err instead of being dereferenced.
  static Json parse(const char* in,
                    std::string& err,
                    JsonParse strategy = JsonParse::STANDARD) {
    if (!in) {
      err = "null input";
      return nullptr;
    }
    return parse(std::string(in), err, strategy);
  }

  // Parses several JSON values that are concatenated or separated by whitespace.
  static std::vector<Json> parse_multi(
      const std::string& in,
      std::string::size_type& parser_stop_pos,
      std::string& err,
      JsonParse strategy = JsonParse::STANDARD);

  // Same as above for callers that do not need the stop position.
  static inline std::vector<Json> parse_multi(
      const std::string& in,
      std::string& err,
      JsonParse strategy = JsonParse::STANDARD) {
    std::string::size_type unused_stop_pos;
    return parse_multi(in, unused_stop_pos, err, strategy);
  }

  // Equality and ordering; the four derived operators are expressed through == and <.
  bool operator==(const Json& rhs) const;
  bool operator<(const Json& rhs) const;
  bool operator!=(const Json& rhs) const { return !operator==(rhs); }
  bool operator<=(const Json& rhs) const { return !rhs.operator<(*this); }
  bool operator>(const Json& rhs) const { return rhs.operator<(*this); }
  bool operator>=(const Json& rhs) const { return !operator<(rhs); }

  /* has_shape(types, err)
   *
   * True when this is a JSON object that has, for every entry in types, a field of the
   * requested type. Otherwise returns false and stores a descriptive message in err.
   */
  using shape = std::initializer_list<std::pair<std::string, Type>>;
  bool has_shape(const shape& types, std::string& err) const;

 private:
  // Shared handle to the underlying value implementation.
  std::shared_ptr<JsonValue> m_ptr;
};

// Internal base of the value hierarchy; JsonValue objects are not exposed to API users.
// Every member is protected and reachable only through the befriended classes.
class JsonValue {
 protected:
  friend class Json;
  friend class JsonInt;
  friend class JsonDouble;

  // Operations every concrete value kind has to supply.
  virtual Json::Type type() const = 0;
  virtual bool equals(const JsonValue* other) const = 0;
  virtual bool less(const JsonValue* other) const = 0;
  virtual void dump(std::string& out) const = 0;

  // Accessors that are virtual but not pure, so a concrete kind may override only some.
  virtual double number_value() const;
  virtual int int_value() const;
  virtual bool bool_value() const;
  virtual const std::string& string_value() const;
  virtual const Json::array& array_items() const;
  virtual const Json& operator[](size_t i) const;
  virtual const Json::object& object_items() const;
  virtual const Json& operator[](const std::string& key) const;

  virtual ~JsonValue() = default;
};

}  // namespace json11
6 changes: 5 additions & 1 deletion include/LightGBM/tree_learner.h
Expand Up @@ -4,9 +4,12 @@

#include <LightGBM/meta.h>
#include <LightGBM/config.h>
#include <LightGBM/json11.hpp>

#include <vector>

using namespace json11;

namespace LightGBM {

/*! \brief forward declaration */
Expand Down Expand Up @@ -44,7 +47,8 @@ class TreeLearner {
* \param is_constant_hessian True if all hessians share the same value
* \param forced_split_json Parsed JSON describing the splits to force at the top of the tree
* \return A trained tree
*/
virtual Tree* Train(const score_t* gradients, const score_t* hessians, bool is_constant_hessian) = 0;
virtual Tree* Train(const score_t* gradients, const score_t* hessians, bool is_constant_hessian,
Json& forced_split_json) = 0;

/*!
* \brief use a existing tree to fit the new gradients and hessians.
Expand Down
3 changes: 2 additions & 1 deletion src/boosting/dart.hpp
Expand Up @@ -32,7 +32,8 @@ class DART: public GBDT {
* \param training_metrics Training metrics
* \param output_model_filename Filename of output model
*/
void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
void Init(const BoostingConfig* config, const Dataset* train_data,
const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, objective_function, training_metrics);
random_for_drop_ = Random(gbdt_config_->drop_seed);
Expand Down
13 changes: 11 additions & 2 deletions src/boosting/gbdt.cpp
Expand Up @@ -3,7 +3,6 @@
#include <LightGBM/utils/openmp_wrapper.h>

#include <LightGBM/utils/common.h>

#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>
Expand Down Expand Up @@ -75,6 +74,16 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;

std::string forced_splits_path = config->forcedsplits_filename;
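// Load the forced-splits JSON file (if a path was configured) and parse it into forced_splits_json_.
// NOTE(review): neither a failed open of the file nor a non-empty parse error in `err` is
// checked below, so a missing or malformed file is not reported here — confirm this is intended.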
if (forced_splits_path != "") {
std::ifstream forced_splits_file(forced_splits_path.c_str());
std::stringstream buffer;
buffer << forced_splits_file.rdbuf();
std::string err;
forced_splits_json_ = Json::parse(buffer.str(), err);
}

objective_function_ = objective_function;
num_tree_per_iteration_ = num_class_;
if (objective_function_ != nullptr) {
Expand Down Expand Up @@ -425,7 +434,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
hess = hessians_.data() + bias;
}

new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_));
new_tree.reset(tree_learner_->Train(grad, hess, is_constant_hessian_, forced_splits_json_));
}

#ifdef TIMETAG
Expand Down

0 comments on commit 84fef71

Please sign in to comment.