Skip to content

Commit

Permalink
Fix returning complex params in hyperparameter search, simplify code.
Browse files Browse the repository at this point in the history
Fix #1741, Fix #1833
  • Loading branch information
andrey-khropov committed Feb 2, 2024
1 parent 6fa51d9 commit 35b8928
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -537,15 +537,15 @@ namespace {
// Adding quantization params if needed
if (gridParams.QuantizationParamsSet.GeneralInfo.IsBordersCountInGrid) {
const TString& paramName = gridParams.QuantizationParamsSet.GeneralInfo.BordersCountParamName;
namedOptionsCollection->IntOptions[paramName] = gridParams.QuantizationParamsSet.BinsCount;
namedOptionsCollection->BestParams[paramName] = NJson::TJsonValue(gridParams.QuantizationParamsSet.BinsCount);
}
if (gridParams.QuantizationParamsSet.GeneralInfo.IsBorderTypeInGrid) {
const TString& paramName = gridParams.QuantizationParamsSet.GeneralInfo.BorderTypeParamName;
namedOptionsCollection->StringOptions[paramName] = ToString(gridParams.QuantizationParamsSet.BorderType);
namedOptionsCollection->BestParams[paramName] = NJson::TJsonValue(ToString(gridParams.QuantizationParamsSet.BorderType));
}
if (gridParams.QuantizationParamsSet.GeneralInfo.IsNanModeInGrid) {
const TString& paramName = gridParams.QuantizationParamsSet.GeneralInfo.NanModeParamName;
namedOptionsCollection->StringOptions[paramName] = ToString(gridParams.QuantizationParamsSet.NanMode);
namedOptionsCollection->BestParams[paramName] = NJson::TJsonValue(ToString(gridParams.QuantizationParamsSet.NanMode));
}
}

Expand Down Expand Up @@ -1146,46 +1146,12 @@ namespace NCB {
// Rebuilds the stored best option values from `options`, considering only the
// entries whose names are listed in `optionsNames`.
//
// options      - map from option name to its JSON value.
// optionsNames - names of the options to store; each one is looked up in
//                `options` with at(), so it is expected to be present there.
//
// NOTE(review): this text comes from a rendered commit diff whose hunk header
// reads "-1146,46 +1146,12". The typed-map statements and the BestParams
// statements look like the removed and added sides shown together - confirm
// against the repository before relying on the combined behaviour.
void TBestOptionValuesWithCvResult::SetOptionsFromJson(
const THashMap<TString, NJson::TJsonValue>& options,
const TVector<TString>& optionsNames) {
// Drop everything stored by a previous call in the per-type maps.
BoolOptions.clear();
IntOptions.clear();
UIntOptions.clear();
DoubleOptions.clear();
StringOptions.clear();
ListOfDoublesOptions.clear();

// Reset BestParams to an empty JSON map and keep a reference to that map
// so the loop below can insert into it directly.
BestParams = NJson::TJsonValue(NJson::JSON_MAP);
auto& bestParamsMap = BestParams.GetMapSafe();

for (const auto& optionName : optionsNames) {
const auto& option = options.at(optionName);
// Dispatch on the JSON type of the value to pick the matching typed map.
NJson::EJsonValueType type = option.GetType();
switch(type) {
case NJson::EJsonValueType::JSON_BOOLEAN: {
BoolOptions[optionName] = option.GetBoolean();
break;
}
case NJson::EJsonValueType::JSON_INTEGER: {
IntOptions[optionName] = option.GetInteger();
break;
}
case NJson::EJsonValueType::JSON_UINTEGER: {
UIntOptions[optionName] = option.GetUInteger();
break;
}
case NJson::EJsonValueType::JSON_DOUBLE: {
DoubleOptions[optionName] = option.GetDouble();
break;
}
case NJson::EJsonValueType::JSON_STRING: {
StringOptions[optionName] = option.GetString();
break;
}
case NJson::EJsonValueType::JSON_ARRAY: {
// Every array element is read with GetDouble(); an empty array adds no
// entry because the map slot is only created inside the loop body.
for (const auto& listElement : option.GetArray()) {
ListOfDoublesOptions[optionName].push_back(listElement.GetDouble());
}
break;
}
default: {
// Any other JSON type is reported through CB_ENSURE with this message.
CB_ENSURE(false, "Error: option value should be bool, int, ui32, double, string or list of doubles");
}
}
// Store the JSON value unchanged under the option name in BestParams.
bestParamsMap.emplace(optionName, options.at(optionName));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include <library/cpp/json/json_value.h>

#include <util/generic/hash.h>
#include <util/generic/fwd.h>
#include <util/generic/maybe.h>
#include <util/generic/string.h>
#include <util/generic/vector.h>
Expand All @@ -32,13 +32,12 @@ namespace NCB {
struct TBestOptionValuesWithCvResult {
public:
TVector<TCVResult> CvResult;
THashMap<TString, bool> BoolOptions;
THashMap<TString, int> IntOptions;
THashMap<TString, ui32> UIntOptions;
THashMap<TString, double> DoubleOptions;
THashMap<TString, TString> StringOptions;
THashMap<TString, TVector<double>> ListOfDoublesOptions;
NJson::TJsonValue BestParams;
public:
TBestOptionValuesWithCvResult()
: BestParams(NJson::JSON_MAP)
{}

void SetOptionsFromJson(
const THashMap<TString, NJson::TJsonValue>& options,
const TVector<TString>& optionsNames);
Expand Down
21 changes: 2 additions & 19 deletions catboost/python-package/catboost/_catboost.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1097,12 +1097,7 @@ cdef extern from "catboost/private/libs/hyperparameter_tuning/hyperparameter_tun

cdef cppclass TBestOptionValuesWithCvResult:
TVector[TCVResult] CvResult
THashMap[TString, bool_t] BoolOptions
THashMap[TString, int] IntOptions
THashMap[TString, ui32] UIntOptions
THashMap[TString, double] DoubleOptions
THashMap[TString, TString] StringOptions
THashMap[TString, TVector[double]] ListOfDoublesOptions
TJsonValue BestParams

cdef void GridSearch(
const TJsonValue& grid,
Expand Down Expand Up @@ -5265,19 +5260,7 @@ cdef class _CatBoost:
)
result_metrics.add(name)

best_params = {}
for key, value in results.BoolOptions:
best_params[to_native_str(key)] = value
for key, value in results.IntOptions:
best_params[to_native_str(key)] = value
for key, value in results.UIntOptions:
best_params[to_native_str(key)] = value
for key, value in results.DoubleOptions:
best_params[to_native_str(key)] = value
for key, value in results.StringOptions:
best_params[to_native_str(key)] = to_native_str(value)
for key, value in results.ListOfDoublesOptions:
best_params[to_native_str(key)] = [float(elem) for elem in value]
best_params = loads(to_native_str(WriteTJsonValue(results.BestParams)))
search_result = {}
search_result["params"] = best_params
if return_cv_results:
Expand Down
42 changes: 42 additions & 0 deletions catboost/python-package/ut/medium/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4180,6 +4180,48 @@ def test_grid_search_several_grids(task_type):
assert results['params']['border_count'] in grids[grid_num]['border_count']


def test_grid_search_complex_params(task_type):
    """Grid search over a grid that mixes scalar and list-valued candidates.

    The `per_float_feature_quantization` candidates are `None` or lists of
    strings rather than plain scalars. After the search, every parameter of
    the grid must be present in `results['params']` and its reported value
    must be one of the candidates supplied for it.
    """
    n25_index = 25
    n75_index = 75
    quantize_25 = f'{n25_index}:border_count=1024'
    quantize_75 = f'{n75_index}:border_count=1024'

    # Synthetic regression data; features are drawn before labels so the
    # seeded generator is consumed in a fixed order.
    train_labels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
    prng = np.random.RandomState(seed=0)
    features = prng.random_sample(size=(1000, 100))
    targets = prng.choice(train_labels, size=1000)
    pool = Pool(features, label=targets)

    quantization_candidates = [
        None,
        [quantize_25],
        [quantize_75],
        [quantize_25, quantize_75],
    ]
    params = {
        'depth': [4, 7, 10],
        'per_float_feature_quantization': quantization_candidates,
        'l2_leaf_reg': [1, 3, 5, 7, 9],
        'iterations': [10],
        'learning_rate': [0.3],
        'verbose': [100]
    }

    search_options = dict(
        cv=3,
        partition_random_seed=42,
        calc_cv_statistics=True,
        search_by_train_test_split=True,
        refit=True,
        shuffle=True,
        stratified=None,
        train_size=0.8,
        verbose=5,
        plot=True,
    )

    cbr = CatBoostRegressor(task_type=task_type, devices='0')
    results = cbr.grid_search(params, pool, **search_options)

    # Each grid parameter must be reported with one of its candidate values.
    best_params = results['params']
    for name, candidates in params.items():
        assert name in best_params and best_params[name] in candidates


def test_feature_importance(task_type):
pool = Pool(TRAIN_FILE, column_description=CD_FILE)
pool_querywise = Pool(QUERYWISE_TRAIN_FILE, column_description=QUERYWISE_CD_FILE)
Expand Down

0 comments on commit 35b8928

Please sign in to comment.