Merge pull request #21692 from UnaNancyOwen:add_softmax
* add apply softmax option to ClassificationModel

* remove default arguments of ClassificationModel::setSoftMax()

* fix build for python

* fix docs warning for setSoftMax()

* add impl for ClassificationModel()

* fix docs build failure caused by trailing whitespace

* move classify() implementation to ClassificationModel_Impl

* move softmax() implementation to ClassificationModel_Impl

* remove softmax() from the public methods of ClassificationModel
UnaNancyOwen committed Mar 7, 2022
1 parent 901e0dd commit 8db7d43
Showing 2 changed files with 99 additions and 11 deletions.
21 changes: 21 additions & 0 deletions modules/dnn/include/opencv2/dnn/dnn.hpp
@@ -1310,6 +1310,9 @@ CV__DNN_INLINE_NS_BEGIN
 class CV_EXPORTS_W_SIMPLE ClassificationModel : public Model
 {
 public:
+    CV_DEPRECATED_EXTERNAL // avoid using in C++ code, will be moved to "protected" (need to fix bindings first)
+    ClassificationModel();
+
     /**
      * @brief Create classification model from network represented in one of the supported formats.
      * An order of @p model and @p config arguments does not matter.
@@ -1324,6 +1327,24 @@ CV__DNN_INLINE_NS_BEGIN
      */
     CV_WRAP ClassificationModel(const Net& network);

+    /**
+     * @brief Set enable/disable softmax post processing option.
+     *
+     * If this option is true, softmax is applied after forward inference within the classify() function
+     * to convert the confidence values to the [0.0, 1.0] range.
+     * This function allows you to toggle this behavior.
+     * Set it to true when the model does not contain a softmax layer itself.
+     * @param[in] enable Whether to apply softmax post processing within the classify() function.
+     */
+    CV_WRAP ClassificationModel& setEnableSoftmaxPostProcessing(bool enable);
+
+    /**
+     * @brief Get enable/disable softmax post processing option.
+     *
+     * This option defaults to false; in that case, softmax post processing is not applied within the classify() function.
+     */
+    CV_WRAP bool getEnableSoftmaxPostProcessing() const;
+
     /** @brief Given the @p input frame, create input blob, run net and return top-1 prediction.
      * @param[in] frame The input image.
      */
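For context, a minimal usage sketch of the new option from C++ (illustrative only, not part of this commit; the model file "classifier.onnx", the image "input.jpg", and the input parameters are placeholder assumptions):

// Illustrative usage of the softmax post-processing option (not part of this commit).
// "classifier.onnx" and "input.jpg" are placeholder names.
#include <opencv2/dnn.hpp>
#include <opencv2/imgcodecs.hpp>

int main()
{
    cv::dnn::ClassificationModel model("classifier.onnx");
    model.setInputParams(1.0 / 255.0, cv::Size(224, 224), cv::Scalar(), /*swapRB=*/true);

    // The (assumed) model ends without a softmax layer, so enable the new option
    // to have classify() return confidences in the [0.0, 1.0] range.
    model.setEnableSoftmaxPostProcessing(true);

    int classId = -1;
    float conf = 0.0f;
    model.classify(cv::imread("input.jpg"), classId, conf);
    return 0;
}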
89 changes: 78 additions & 11 deletions modules/dnn/src/model.cpp
@@ -197,28 +197,95 @@ void Model::predict(InputArray frame, OutputArrayOfArrays outs) const
 }

+
+class ClassificationModel_Impl : public Model::Impl
+{
+public:
+    virtual ~ClassificationModel_Impl() {}
+    ClassificationModel_Impl() : Impl() {}
+    ClassificationModel_Impl(const ClassificationModel_Impl&) = delete;
+    ClassificationModel_Impl(ClassificationModel_Impl&&) = delete;
+
+    void setEnableSoftmaxPostProcessing(bool enable)
+    {
+        applySoftmax = enable;
+    }
+
+    bool getEnableSoftmaxPostProcessing() const
+    {
+        return applySoftmax;
+    }
+
+    std::pair<int, float> classify(InputArray frame)
+    {
+        std::vector<Mat> outs;
+        processFrame(frame, outs);
+        CV_Assert(outs.size() == 1);
+
+        Mat out = outs[0].reshape(1, 1);
+
+        if(getEnableSoftmaxPostProcessing())
+        {
+            softmax(out, out);
+        }
+
+        double conf;
+        Point maxLoc;
+        cv::minMaxLoc(out, nullptr, &conf, nullptr, &maxLoc);
+        return {maxLoc.x, static_cast<float>(conf)};
+    }
+
+protected:
+    void softmax(InputArray inblob, OutputArray outblob)
+    {
+        const Mat input = inblob.getMat();
+        outblob.create(inblob.size(), inblob.type());
+
+        Mat exp;
+        const float max = *std::max_element(input.begin<float>(), input.end<float>());
+        cv::exp((input - max), exp);
+        outblob.getMat() = exp / cv::sum(exp)[0];
+    }
+
+protected:
+    bool applySoftmax = false;
+};
+
+ClassificationModel::ClassificationModel()
+    : Model()
+{
+    // nothing
+}

 ClassificationModel::ClassificationModel(const String& model, const String& config)
-    : Model(model, config)
+    : ClassificationModel(readNet(model, config))
 {
     // nothing
 }

 ClassificationModel::ClassificationModel(const Net& network)
-    : Model(network)
+    : Model()
 {
-    // nothing
+    impl = makePtr<ClassificationModel_Impl>();
+    impl->initNet(network);
 }

-std::pair<int, float> ClassificationModel::classify(InputArray frame)
+ClassificationModel& ClassificationModel::setEnableSoftmaxPostProcessing(bool enable)
 {
-    std::vector<Mat> outs;
-    impl->processFrame(frame, outs);
-    CV_Assert(outs.size() == 1);
+    CV_Assert(impl != nullptr && impl.dynamicCast<ClassificationModel_Impl>() != nullptr);
+    impl.dynamicCast<ClassificationModel_Impl>()->setEnableSoftmaxPostProcessing(enable);
+    return *this;
 }

-    double conf;
-    cv::Point maxLoc;
-    minMaxLoc(outs[0].reshape(1, 1), nullptr, &conf, nullptr, &maxLoc);
-    return {maxLoc.x, static_cast<float>(conf)};
+bool ClassificationModel::getEnableSoftmaxPostProcessing() const
+{
+    CV_Assert(impl != nullptr && impl.dynamicCast<ClassificationModel_Impl>() != nullptr);
+    return impl.dynamicCast<ClassificationModel_Impl>()->getEnableSoftmaxPostProcessing();
+}

+std::pair<int, float> ClassificationModel::classify(InputArray frame)
+{
+    CV_Assert(impl != nullptr && impl.dynamicCast<ClassificationModel_Impl>() != nullptr);
+    return impl.dynamicCast<ClassificationModel_Impl>()->classify(frame);
+}

 void ClassificationModel::classify(InputArray frame, int& classId, float& conf)
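A note on the softmax helper introduced above: subtracting the row maximum before calling exp() keeps the exponentials in a safe numeric range, and the shift cancels out during normalization, so the resulting probabilities are unchanged. The following self-contained sketch mirrors that computation for illustration; softmaxRow() is a hypothetical helper, not part of the OpenCV API.

// Standalone sketch of the numerically stable softmax applied by classify()
// when post-processing is enabled. Illustrative only; softmaxRow() is not OpenCV API.
#include <opencv2/core.hpp>
#include <algorithm>
#include <iostream>

static cv::Mat softmaxRow(const cv::Mat& logits) // expects a 1xN CV_32F row
{
    // Shift by the maximum so exp() cannot overflow; the shift cancels in the division.
    const float maxVal = *std::max_element(logits.begin<float>(), logits.end<float>());
    cv::Mat expVals;
    cv::exp(logits - maxVal, expVals);
    return expVals / cv::sum(expVals)[0]; // normalize so the row sums to 1
}

int main()
{
    const cv::Mat logits = (cv::Mat_<float>(1, 3) << 2.0f, 1.0f, 0.1f);
    std::cout << softmaxRow(logits) << std::endl; // roughly [0.659, 0.242, 0.099]
    return 0;
}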
