Skip to content

Commit

Permalink
Merge pull request #651 from jolibrain/ssd_options
Browse files Browse the repository at this point in the history
Additional API control over useful SSD training / deploy parameters
  • Loading branch information
beniz committed Oct 17, 2019
2 parents de8f406 + 4e2ade7 commit 2da03ac
Showing 1 changed file with 58 additions and 1 deletion.
59 changes: 58 additions & 1 deletion src/backends/caffe/caffelib.cc
Expand Up @@ -696,13 +696,47 @@ namespace dd
caffe::ReadProtoFromTextFile(dest_net,&net_param);
caffe::ReadProtoFromTextFile(dest_deploy_net,&deploy_net_param);

// get ssd detailed params, if any.
double ssd_expand_prob = -1.0;
double ssd_max_expand_ratio = -1.0;
std::string ssd_mining_type;
double ssd_neg_pos_ratio = -1.0;
double ssd_neg_overlap = -1.0;
int ssd_keep_top_k = -1;
if (ad.has("ssd_expand_prob"))
ssd_expand_prob = ad.get("ssd_expand_prob").get<double>();
if (ad.has("ssd_max_expand_ratio"))
ssd_max_expand_ratio = ad.get("ssd_max_expand_ratio").get<double>();
if (ad.has("ssd_mining_type"))
ssd_mining_type = ad.get("ssd_mining_type").get<std::string>();
if (ad.has("ssd_neg_pos_ratio"))
ssd_neg_pos_ratio = ad.get("ssd_neg_pos_ratio").get<double>();
if (ad.has("ssd_neg_overlap"))
ssd_neg_pos_ratio = ad.get("ssd_neg_overlap").get<double>();
if (ad.has("ssd_keep_top_k"))
ssd_keep_top_k = ad.get("ssd_keep_top_k").get<int>();

//- if finetuning, change the proper layer names
std::string postfix = "_ftune";
const bool finetune = (ad.has("finetuning") && ad.get("finetuning").get<bool>());
int k = net_param.layer_size();
for (int l=0;l<k;l++)
{
caffe::LayerParameter *lparam = net_param.mutable_layer(l);

if (l == 0 && lparam->type() == "AnnotatedData")
{
caffe::TransformationParameter *trparam = lparam->mutable_transform_param();
if (ssd_expand_prob >= 0.0 || ssd_max_expand_ratio >= 0.0)
{
caffe::ExpansionParameter *exparam = trparam->mutable_expand_param();
if (ssd_expand_prob >= 0.0)
exparam->set_prob(ssd_expand_prob);
if (ssd_max_expand_ratio >= 0.0)
exparam->set_max_expand_ratio(ssd_max_expand_ratio);
}
}

if (finetune)
{
if (lparam->name().find("mbox_conf") != std::string::npos
Expand Down Expand Up @@ -730,11 +764,32 @@ namespace dd
//- set correct layer parameters based on nclasses
if (lparam->name() == "mbox_loss" || lparam->name() == "odm_loss")
{
lparam->mutable_multibox_loss_param()->set_num_classes(_nclasses);
lparam->mutable_multibox_loss_param()->set_num_classes(_nclasses);
if (ssd_mining_type == "MAX_NEGATIVE")
lparam->mutable_multibox_loss_param()->set_mining_type(caffe::MultiBoxLossParameter::MAX_NEGATIVE);
else if (ssd_mining_type == "HARD_EXAMPLE")
lparam->mutable_multibox_loss_param()->set_mining_type(caffe::MultiBoxLossParameter::HARD_EXAMPLE);
if (ssd_neg_pos_ratio >= 1.0)
lparam->mutable_multibox_loss_param()->set_neg_pos_ratio(ssd_neg_pos_ratio);
if (ssd_neg_overlap >= 0.0)
lparam->mutable_multibox_loss_param()->set_neg_overlap(ssd_neg_overlap);
}
else if (lparam->name() == "arm_loss")
{
if (ssd_mining_type == "MAX_NEGATIVE")
lparam->mutable_multibox_loss_param()->set_mining_type(caffe::MultiBoxLossParameter::MAX_NEGATIVE);
else if (ssd_mining_type == "HARD_EXAMPLE")
lparam->mutable_multibox_loss_param()->set_mining_type(caffe::MultiBoxLossParameter::HARD_EXAMPLE);
if (ssd_neg_pos_ratio >= 1.0)
lparam->mutable_multibox_loss_param()->set_neg_pos_ratio(ssd_neg_pos_ratio);
if (ssd_neg_overlap >= 0.0)
lparam->mutable_multibox_loss_param()->set_neg_overlap(ssd_neg_overlap);
}
else if (lparam->name() == "detection_out")
{
lparam->mutable_detection_output_param()->set_num_classes(_nclasses);
if (ssd_keep_top_k >= 1.0)
lparam->mutable_detection_output_param()->set_keep_top_k(ssd_keep_top_k);
}
else if (lparam->name() == "detection_eval")
{
Expand Down Expand Up @@ -809,6 +864,8 @@ namespace dd
else if (lparam->name() == "detection_out")
{
lparam->mutable_detection_output_param()->set_num_classes(_nclasses);
if (ssd_keep_top_k >= 1.0)
lparam->mutable_detection_output_param()->set_keep_top_k(ssd_keep_top_k);
}
else if (lparam->name().find("mbox_conf_reshape") != std::string::npos
|| lparam->name().find("odm_conf_reshape") != std::string::npos)
Expand Down

0 comments on commit 2da03ac

Please sign in to comment.