Commit

trainAuto: uses parallel_for_ to dispatch all parameter combinations to test

r2d3 committed Aug 7, 2017
1 parent 2f4a3e4 commit 8762947
Showing 1 changed file with 112 additions and 57 deletions.
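
The commit adopts OpenCV's standard data-parallel idiom: a ParallelLoopBody subclass does the work for a sub-range of item indices, and cv::parallel_for_ carves the full index range up across the worker threads. A minimal sketch of that idiom, using a hypothetical SquareBody as the work item (illustrative only, not part of this commit):

    #include <opencv2/core.hpp>
    #include <vector>

    // Squares in[i] into out[i] for the indices handed to this worker.
    // Each index is written by exactly one worker, so no locking is needed.
    class SquareBody : public cv::ParallelLoopBody
    {
    public:
        SquareBody(const std::vector<int>& _in, std::vector<int>& _out)
            : in(_in), out(_out) {}

        void operator()( const cv::Range& range ) const
        {
            for( int i = range.start; i < range.end; i++ )
                out[i] = in[i]*in[i];
        }

    private:
        const std::vector<int>& in;
        std::vector<int>& out;
    };

    // usage:
    //   std::vector<int> in(16, 3), out(16);
    //   cv::parallel_for_(cv::Range(0, (int)in.size()), SquareBody(in, out));

The TrainAutoBody class added below has exactly this shape: index p selects one parameter combination, and slot result[p] receives its cross-validation error, so the workers never contend for shared state.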
modules/ml/src/svm.cpp: 169 changes (112 additions, 57 deletions)
@@ -1636,6 +1636,101 @@ class SVMImpl : public SVM
         return true;
     }

+    class TrainAutoBody : public ParallelLoopBody
+    {
+    public:
+        TrainAutoBody(const vector<SvmParams>& _parameters,
+                      const cv::Mat& _samples,
+                      const cv::Mat& _responses,
+                      const cv::Mat& _labels,
+                      const vector<int>& _sidx,
+                      bool _is_classification,
+                      int _k_fold,
+                      std::vector<double>& _result) :
+        parameters(_parameters), samples(_samples), responses(_responses), labels(_labels),
+        sidx(_sidx), is_classification(_is_classification), k_fold(_k_fold), result(_result)
+        {}
+
+        void operator()( const cv::Range& range ) const
+        {
+            int sample_count = samples.rows;
+            int var_count_ = samples.cols;
+            size_t sample_size = var_count_*samples.elemSize();
+
+            int test_sample_count = (sample_count + k_fold/2)/k_fold;
+            int train_sample_count = sample_count - test_sample_count;
+
+            // Use a local instance
+            cv::Ptr<SVMImpl> svm = makePtr<SVMImpl>();
+            svm->class_labels = labels;
+
+            int rtype = responses.type();
+
+            Mat temp_train_samples(train_sample_count, var_count_, CV_32F);
+            Mat temp_test_samples(test_sample_count, var_count_, CV_32F);
+            Mat temp_train_responses(train_sample_count, 1, rtype);
+            Mat temp_test_responses;
+
+            int i, j, k, p;
+
+            for( p = range.start; p < range.end; p++ )
+            {
+                svm->setParams(parameters[p]);
+
+                double error = 0;
+                for( k = 0; k < k_fold; k++ )
+                {
+                    int start = (k*sample_count + k_fold/2)/k_fold;
+                    for( i = 0; i < train_sample_count; i++ )
+                    {
+                        j = sidx[(i+start)%sample_count];
+                        memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
+                        if( is_classification )
+                            temp_train_responses.at<int>(i) = responses.at<int>(j);
+                        else if( !responses.empty() )
+                            temp_train_responses.at<float>(i) = responses.at<float>(j);
+                    }
+
+                    // Train SVM on <train_size> samples
+                    if( !svm->do_train( temp_train_samples, temp_train_responses ))
+                        continue;
+
+                    for( i = 0; i < test_sample_count; i++ )
+                    {
+                        j = sidx[(i+start+train_sample_count) % sample_count];
+                        memcpy(temp_test_samples.ptr(i), samples.ptr(j), sample_size);
+                    }
+
+                    svm->predict(temp_test_samples, temp_test_responses, 0);
+                    for( i = 0; i < test_sample_count; i++ )
+                    {
+                        float val = temp_test_responses.at<float>(i);
+                        j = sidx[(i+start+train_sample_count) % sample_count];
+                        if( is_classification )
+                            error += (float)(val != responses.at<int>(j));
+                        else
+                        {
+                            val -= responses.at<float>(j);
+                            error += val*val;
+                        }
+                    }
+                }
+
+                result[p] = error;
+            }
+        }
+
+    private:
+        const vector<SvmParams>& parameters;
+        const cv::Mat& samples;
+        const cv::Mat& responses;
+        const cv::Mat& labels;
+        const vector<int>& sidx;
+        bool is_classification;
+        int k_fold;
+        std::vector<double>& result;
+    };
+
     bool trainAuto( const Ptr<TrainData>& data, int k_fold,
                     ParamGrid C_grid, ParamGrid gamma_grid, ParamGrid p_grid,
                     ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
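
The fold geometry each worker recomputes is plain integer arithmetic; the "+ k_fold/2" terms round to nearest instead of truncating. Worked numbers (illustrative values, not from the commit):

    // sample_count = 103, k_fold = 10
    // test_sample_count  = (103 + 10/2)/10 = 108/10 = 10
    // train_sample_count = 103 - 10        = 93
    // start(k) = (k*103 + 5)/10 for k = 0..9:
    //   0, 10, 21, 31, 41, 52, 62, 72, 82, 93
    // Fold k trains on the 93 shuffled indices beginning at sidx offset
    // start(k), wrapping modulo sample_count, and tests on the remaining 10,
    // so the rounding remainder is spread across folds rather than dumped
    // into the last one.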
@@ -1713,12 +1808,11 @@ class SVMImpl : public SVM

         int sample_count = samples.rows;
         var_count = samples.cols;
-        size_t sample_size = var_count*samples.elemSize();

         vector<int> sidx;
         setRangeVector(sidx, sample_count);

-        int i, j, k;
+        int i, k;

         // randomly permute training samples
         for( i = 0; i < sample_count; i++ )
@@ -1764,75 +1858,36 @@ class SVMImpl : public SVM
             }
         }

-        int test_sample_count = (sample_count + k_fold/2)/k_fold;
-        int train_sample_count = sample_count - test_sample_count;
-
-        SvmParams best_params = params;
-        double min_error = FLT_MAX;
-
-        int rtype = responses.type();
-
-        Mat temp_train_samples(train_sample_count, var_count, CV_32F);
-        Mat temp_test_samples(test_sample_count, var_count, CV_32F);
-        Mat temp_train_responses(train_sample_count, 1, rtype);
-        Mat temp_test_responses;
-
         // If grid.minVal == grid.maxVal, this will allow one and only one pass through the loop with params.var = grid.minVal.
         #define FOR_IN_GRID(var, grid) \
             for( params.var = grid.minVal; params.var == grid.minVal || params.var < grid.maxVal; params.var = (grid.minVal == grid.maxVal) ? grid.maxVal + 1 : params.var * grid.logStep )

+        // Create the list of parameters to test
+        std::vector<SvmParams> parameters;
         FOR_IN_GRID(C, C_grid)
         FOR_IN_GRID(gamma, gamma_grid)
         FOR_IN_GRID(p, p_grid)
         FOR_IN_GRID(nu, nu_grid)
         FOR_IN_GRID(coef0, coef_grid)
         FOR_IN_GRID(degree, degree_grid)
         {
-            // make sure we updated the kernel and other parameters
-            setParams(params);
-
-            double error = 0;
-            for( k = 0; k < k_fold; k++ )
-            {
-                int start = (k*sample_count + k_fold/2)/k_fold;
-                for( i = 0; i < train_sample_count; i++ )
-                {
-                    j = sidx[(i+start)%sample_count];
-                    memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
-                    if( is_classification )
-                        temp_train_responses.at<int>(i) = responses.at<int>(j);
-                    else if( !responses.empty() )
-                        temp_train_responses.at<float>(i) = responses.at<float>(j);
-                }
-
-                // Train SVM on <train_size> samples
-                if( !do_train( temp_train_samples, temp_train_responses ))
-                    continue;
+            parameters.push_back(params);
+        }

-                for( i = 0; i < test_sample_count; i++ )
-                {
-                    j = sidx[(i+start+train_sample_count) % sample_count];
-                    memcpy(temp_test_samples.ptr(i), samples.ptr(j), sample_size);
-                }
+        std::vector<double> result(parameters.size());
+        TrainAutoBody invoker(parameters, samples, responses, class_labels, sidx,
+                              is_classification, k_fold, result);
+        parallel_for_(cv::Range(0,(int)parameters.size()), invoker);

-                predict(temp_test_samples, temp_test_responses, 0);
-                for( i = 0; i < test_sample_count; i++ )
-                {
-                    float val = temp_test_responses.at<float>(i);
-                    j = sidx[(i+start+train_sample_count) % sample_count];
-                    if( is_classification )
-                        error += (float)(val != responses.at<int>(j));
-                    else
-                    {
-                        val -= responses.at<float>(j);
-                        error += val*val;
-                    }
-                }
-            }
-            if( min_error > error )
+        // Extract the best parameters
+        SvmParams best_params = params;
+        double min_error = FLT_MAX;
+        for( i = 0; i < (int)result.size(); i++ )
+        {
+            if( result[i] < min_error )
             {
-                min_error = error;
-                best_params = params;
+                min_error = result[i];
+                best_params = parameters[i];
             }
         }

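For a sense of what parallel_for_ is dispatching: FOR_IN_GRID enumerates each grid geometrically, and the six nested loops emit the cross product. Illustrative expansion with an assumed C_grid = ParamGrid(0.1, 1000, 10):

    // C visits 0.1, 1, 10, 100: minVal is admitted unconditionally, then the
    // value is multiplied by logStep while it stays below maxVal.
    // A degenerate grid (minVal == maxVal) contributes exactly one value.
    // With a 4-value gamma grid as well and the other four grids degenerate,
    // parameters.size() == 4*4 == 16; parallel_for_ then runs those 16
    // independent k-fold evaluations concurrently, and result[p] ends up
    // holding the accumulated error for parameters[p].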

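For reference, a minimal caller that exercises the new code path (the dataset and kernel choice here are assumptions for illustration, not part of the commit):

    #include <opencv2/ml.hpp>

    int main()
    {
        using namespace cv;
        using namespace cv::ml;

        // Toy 2-D dataset: label is 1 when the first coordinate exceeds 0.5.
        Mat samples(100, 2, CV_32F), labels(100, 1, CV_32S);
        randu(samples, Scalar::all(0), Scalar::all(1));
        for( int i = 0; i < samples.rows; i++ )
            labels.at<int>(i) = samples.at<float>(i, 0) > 0.5f ? 1 : 0;

        Ptr<SVM> svm = SVM::create();
        svm->setKernel(SVM::RBF);

        Ptr<TrainData> tdata = TrainData::create(samples, ROW_SAMPLE, labels);
        svm->trainAuto(tdata, 10);  // 10-fold grid search; the per-parameter
                                    // evaluations now run in parallel
        return 0;
    }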