diff --git a/src/backends/torch/torchdataaug.cc b/src/backends/torch/torchdataaug.cc
index c657ad360..2a231df88 100644
--- a/src/backends/torch/torchdataaug.cc
+++ b/src/backends/torch/torchdataaug.cc
@@ -395,17 +395,24 @@ namespace dd
     if (sample)
       {
+        int img_width = src.cols;
+        int img_height = src.rows;
+        std::uniform_int_distribution<int> uniform_int_crop_x(
+            0, img_width - cp._crop_size);
+        std::uniform_int_distribution<int> uniform_int_crop_y(
+            0, img_height - cp._crop_size);
+
 #pragma omp critical
         {
           if (test)
             {
-              crop_x = cp._uniform_int_crop_x(_rnd_test_gen);
-              crop_y = cp._uniform_int_crop_y(_rnd_test_gen);
+              crop_x = uniform_int_crop_x(_rnd_test_gen);
+              crop_y = uniform_int_crop_y(_rnd_test_gen);
            }
          else
            {
-              crop_x = cp._uniform_int_crop_x(_rnd_gen);
-              crop_y = cp._uniform_int_crop_y(_rnd_gen);
+              crop_x = uniform_int_crop_x(_rnd_gen);
+              crop_y = uniform_int_crop_y(_rnd_gen);
            }
        }
      }

@@ -464,20 +471,22 @@ namespace dd

 #pragma omp critical
    {
+      int img_width = src.cols;
+      int img_height = src.rows;
      // get shape and area to erase
      int w = 0, h = 0, rect_x = 0, rect_y = 0;
      if (cp._w == 0 && cp._h == 0)
        {
-          float s = cp._uniform_real_cutout_s(_rnd_gen) * cp._img_width
-                    * cp._img_height; // area
+          float s = cp._uniform_real_cutout_s(_rnd_gen) * img_width
+                    * img_height; // area
          float r = cp._uniform_real_cutout_r(_rnd_gen); // aspect ratio

-          w = std::min(cp._img_width,
+          w = std::min(img_width,
                       static_cast<int>(std::floor(std::sqrt(s / r))));
-          h = std::min(cp._img_height,
+          h = std::min(img_height,
                       static_cast<int>(std::floor(std::sqrt(s * r))));
-          std::uniform_int_distribution<int> distx(0, cp._img_width - w);
-          std::uniform_int_distribution<int> disty(0, cp._img_height - h);
+          std::uniform_int_distribution<int> distx(0, img_width - w);
+          std::uniform_int_distribution<int> disty(0, img_height - h);
          rect_x = distx(_rnd_gen);
          rect_y = disty(_rnd_gen);
        }
diff --git a/src/backends/torch/torchdataaug.h b/src/backends/torch/torchdataaug.h
index 453248c57..fceee521a 100644
--- a/src/backends/torch/torchdataaug.h
+++ b/src/backends/torch/torchdataaug.h
@@ -33,67 +33,34 @@ namespace dd
 {
-  class ImgAugParams
+  class CropParams
   {
   public:
-    ImgAugParams() : _img_width(224), _img_height(224)
+    CropParams()
    {
    }

-    ImgAugParams(const int &img_width, const int &img_height)
-        : _img_width(img_width), _img_height(img_height)
+    CropParams(const int &crop_size) : _crop_size(crop_size)
    {
    }

-    ~ImgAugParams()
-    {
-    }
-
-    int _img_width;
-    int _img_height;
-  };
-
-  class CropParams : public ImgAugParams
-  {
-  public:
-    CropParams() : ImgAugParams()
-    {
-    }
-
-    CropParams(const int &crop_size, const int &img_width,
-               const int &img_height)
-        : ImgAugParams(img_width, img_height), _crop_size(crop_size)
-    {
-      if (_crop_size > 0)
-        {
-          _uniform_int_crop_x
-              = std::uniform_int_distribution<int>(0, _img_width - _crop_size);
-          _uniform_int_crop_y = std::uniform_int_distribution<int>(
-              0, _img_height - _crop_size);
-        }
-    }
-
    ~CropParams()
    {
    }

    // default params
    int _crop_size = -1;
-    std::uniform_int_distribution<int> _uniform_int_crop_x;
-    std::uniform_int_distribution<int> _uniform_int_crop_y;
    int _test_crop_samples = 1; /**< number of sampled crops (at test time).
                                 */
  };

-  class CutoutParams : public ImgAugParams
+  class CutoutParams
  {
  public:
-    CutoutParams() : ImgAugParams()
+    CutoutParams()
    {
    }

-    CutoutParams(const float &prob, const int &img_width,
-                 const int &img_height)
-        : ImgAugParams(img_width, img_height), _prob(prob)
+    CutoutParams(const float &prob) : _prob(prob)
    {
      _uniform_real_cutout_s
          = std::uniform_real_distribution<float>(_cutout_sl, _cutout_sh);
@@ -287,11 +254,6 @@ namespace dd
          _uniform_real_1(0.0, 1.0), _bernouilli(0.5),
          _uniform_int_rotate(0, 3)
    {
-      if (_crop_params._crop_size > 0)
-        {
-          _cutout_params._img_width = _crop_params._crop_size;
-          _cutout_params._img_height = _crop_params._crop_size;
-        }
      reset_rnd_test_gen();
    }
diff --git a/src/backends/torch/torchdataset.cc b/src/backends/torch/torchdataset.cc
index d0ebc2bbb..5fae3546b 100644
--- a/src/backends/torch/torchdataset.cc
+++ b/src/backends/torch/torchdataset.cc
@@ -227,7 +227,7 @@ namespace dd
          torch::load(targett, targetstream);
        }

-      if (bgr.cols != width || bgr.rows != height)
+      if (width > 0 && height > 0 && (bgr.cols != width || bgr.rows != height))
        {
          cv::resize(bgr, bgr, cv::Size(width, height), 0, 0,
                     cv::INTER_CUBIC);
@@ -860,10 +860,13 @@ namespace dd
      std::ifstream infile(bboxfname);
      std::string line;

-      double wfactor = static_cast<double>(inputc->_width)
-                       / static_cast<double>(orig_width);
-      double hfactor = static_cast<double>(inputc->_height)
-                       / static_cast<double>(orig_height);
+      double wfactor = inputc->_width > 0 ? static_cast<double>(inputc->_width)
+                                                / static_cast<double>(orig_width)
+                                          : 1;
+      double hfactor = inputc->_height > 0
+                           ? static_cast<double>(inputc->_height)
+                                 / static_cast<double>(orig_height)
+                           : 1;

      while (std::getline(infile, line))
        {
diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc
index bccef07a6..294476a5b 100644
--- a/src/backends/torch/torchlib.cc
+++ b/src/backends/torch/torchlib.cc
@@ -701,8 +701,7 @@ namespace dd
        if (ad_mllib.has("crop_size"))
          {
            int crop_size = ad_mllib.get("crop_size").get<int>();
-            crop_params
-                = CropParams(crop_size, inputc.width(), inputc.height());
+            crop_params = CropParams(crop_size);
            if (ad_mllib.has("test_crop_samples"))
              crop_params._test_crop_samples
                  = ad_mllib.get("test_crop_samples").get<int>();
@@ -712,8 +711,7 @@ namespace dd
        if (ad_mllib.has("cutout"))
          {
            float cutout = ad_mllib.get("cutout").get<double>();
-            cutout_params
-                = CutoutParams(cutout, inputc.width(), inputc.height());
+            cutout_params = CutoutParams(cutout);
            this->_logger->info("cutout: {}", cutout);
          }
        GeometryParams geometry_params;
@@ -1640,6 +1638,10 @@ namespace dd
                throw MLLibInternalException(
                    "Couldn't find original image size for " + uri);
              }
+            int src_width
+                = inputc.width() > 0 ? inputc.width() : cols - 1;
+            int src_height
+                = inputc.height() > 0 ? inputc.height() : rows - 1;

            APIData results_ad;
            std::vector<double> probs;
@@ -1676,10 +1678,10 @@ namespace dd
                    this->_mlmodel.get_hcorresp(labels_acc[j]));

                double bbox[] = {
-                  bboxes_acc[j][0] / inputc.width() * (cols - 1),
-                  bboxes_acc[j][1] / inputc.height() * (rows - 1),
-                  bboxes_acc[j][2] / inputc.width() * (cols - 1),
-                  bboxes_acc[j][3] / inputc.height() * (rows - 1),
+                  bboxes_acc[j][0] / src_width * (cols - 1),
+                  bboxes_acc[j][1] / src_height * (rows - 1),
+                  bboxes_acc[j][2] / src_width * (cols - 1),
+                  bboxes_acc[j][3] / src_height * (rows - 1),
                };

                // clamp bbox
diff --git a/src/imginputfileconn.h b/src/imginputfileconn.h
index a86c8227f..4f93208e4 100644
--- a/src/imginputfileconn.h
+++ b/src/imginputfileconn.h
@@ -100,9 +100,9 @@ namespace dd
    {
      if (_scaled)
        scale(src, dst);
-      else if (_width == 0 || _height == 0)
+      else if (_width < 0 || _height < 0)
        {
-          if (_width == 0 && _height == 0)
+          if (_width < 0 && _height < 0)
            {
              // Do nothing and keep native resolution. May cause issues if
              // batched images are different resolutions
@@ -199,9 +199,9 @@ namespace dd
    {
      if (_scaled)
        scale_cuda(src, dst);
-      else if (_width == 0 || _height == 0)
+      else if (_width < 0 || _height < 0)
        {
-          if (_width == 0 && _height == 0)
+          if (_width < 0 && _height < 0)
            {
              // Do nothing and keep native resolution. May cause issues if
              // batched images are different resolutions
diff --git a/tests/ut-torchapi.cc b/tests/ut-torchapi.cc
index bebfaf017..85f1615b7 100644
--- a/tests/ut-torchapi.cc
+++ b/tests/ut-torchapi.cc
@@ -394,6 +394,46 @@ TEST(torchapi, service_predict_object_detection)
  ASSERT_EQ(preds_best.Size(), 3);
 }

+TEST(torchapi, service_predict_object_detection_any_size)
+{
+  JsonAPI japi;
+  std::string sname = "detectserv";
+  std::string jstr
+      = "{\"mllib\":\"torch\",\"description\":\"fasterrcnn\",\"type\":"
+        "\"supervised\",\"model\":{\"repository\":\""
+        + detect_repo
+        + "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
+          "-1,\"width\":-1,\"rgb\":true,\"scale\":0.0039},\"mllib\":{"
+          "\"template\":\"fasterrcnn\"}}}";
+
+  std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
+  ASSERT_EQ(created_str, joutstr);
+  std::string jpredictstr
+      = "{\"service\":\"detectserv\",\"parameters\":{"
+        "\"input\":{\"height\":-1,"
+        "\"width\":-1},\"output\":{\"bbox\":true, "
+        "\"best_bbox\":1,\"confidence_threshold\":0.8}},\"data\":[\""
+        + detect_train_repo_fasterrcnn + "/imgs/000550-L.jpg\"]}";
+
+  joutstr = japi.jrender(japi.service_predict(jpredictstr));
+  JDoc jd;
+  std::cout << "joutstr=" << joutstr << std::endl;
+  jd.Parse(joutstr.c_str());
+  ASSERT_TRUE(!jd.HasParseError());
+  ASSERT_EQ(200, jd["status"]["code"]);
+  ASSERT_TRUE(jd["body"]["predictions"].IsArray());
+
+  auto &preds = jd["body"]["predictions"][0]["classes"];
+  std::string cl1 = preds[0]["cat"].GetString();
+  ASSERT_TRUE(cl1 == "car");
+  ASSERT_TRUE(preds[0]["prob"].GetDouble() > 0.9);
+  auto &bbox = preds[0]["bbox"];
+  ASSERT_NEAR(bbox["xmin"].GetDouble(), 258.0, 5.0);
+  ASSERT_NEAR(bbox["ymin"].GetDouble(), 333.0, 5.0);
+  ASSERT_NEAR(bbox["xmax"].GetDouble(), 401.0, 5.0);
+  ASSERT_NEAR(bbox["ymax"].GetDouble(), 448.0, 5.0);
+}
+
 TEST(torchapi, service_predict_segmentation)
 {
  JsonAPI japi;
@@ -2748,6 +2788,107 @@ TEST(torchapi, service_train_object_detection_translation)
  fileops::remove_dir(detect_train_repo_yolox + "test_0.lmdb");
 }

+TEST(torchapi, service_train_object_detection_yolox_any_size)
+{
+  // Test with arbitrary image size: width = -1, height = -1
+  setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
+  torch::manual_seed(torch_seed);
+  at::globalContext().setDeterministicCuDNN(true);
+
+  JsonAPI japi;
+  std::string sname = "detectserv";
+  std::string jstr
+      = "{\"mllib\":\"torch\",\"description\":\"yolox\",\"type\":"
+        "\"supervised\",\"model\":{\"repository\":\""
+        + detect_train_repo_yolox
+        + "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
+          "-1,\"width\":-1,\"rgb\":true,\"bbox\":true,\"db\":true},"
+          "\"mllib\":{\"template\":\"yolox\",\"gpu\":true,"
+          "\"nclasses\":2}}}";
+
+  std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
+  ASSERT_EQ(created_str, joutstr);
+
+  // Train
+  std::string jtrainstr
+      = "{\"service\":\"detectserv\",\"async\":false,\"parameters\":{"
+        "\"mllib\":{\"solver\":{\"iterations\":3"
+        + std::string("") //+ iterations_detection
+        + ",\"base_lr\":" + torch_lr
+        + ",\"iter_size\":2,\"solver_"
+          "type\":\"ADAM\",\"test_interval\":200},\"net\":{\"batch_size\":2,"
+          "\"test_batch_size\":1,\"reg_weight\":0.5},\"resume\":false,"
+          "\"mirror\":true,\"rotate\":true,\"crop_size\":512,"
+          "\"test_crop_samples\":10,"
+          "\"cutout\":0.1,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
+          "true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
+          "\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
+          "\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":true,"
+          "\"shuffle\":true},\"output\":{\"measure\":[\"map-05\",\"map-50\","
+          "\"map-90\"]}},\"data\":[\""
+        + fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}";
+
+  joutstr = japi.jrender(japi.service_train(jtrainstr));
+  JDoc jd;
+  std::cout << "joutstr=" << joutstr << std::endl;
+  jd.Parse(joutstr.c_str());
+  ASSERT_TRUE(!jd.HasParseError());
+  ASSERT_EQ(201, jd["status"]["code"]);
+
+  // ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations";
+  ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() <= 1.0) << "map";
+  ASSERT_TRUE(jd["body"]["measure"]["map-05"].GetDouble() <= 1.0) << "map-05";
+  ASSERT_TRUE(jd["body"]["measure"]["map-50"].GetDouble() <= 1.0) << "map-50";
+  ASSERT_TRUE(jd["body"]["measure"]["map-90"].GetDouble() <= 1.0) << "map-90";
+  // ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map";
+
+  // check metrics
+  auto &meas = jd["body"]["measure"];
+  ASSERT_TRUE(meas.HasMember("iou_loss"));
+  ASSERT_TRUE(meas.HasMember("conf_loss"));
+  ASSERT_TRUE(meas.HasMember("cls_loss"));
+  ASSERT_TRUE(meas.HasMember("l1_loss"));
+  ASSERT_TRUE(meas.HasMember("train_loss"));
+  ASSERT_TRUE(
+      std::abs(meas["train_loss"].GetDouble()
+               - (meas["iou_loss"].GetDouble() * 0.5
+                  + meas["cls_loss"].GetDouble() + meas["l1_loss"].GetDouble()
+                  + meas["conf_loss"].GetDouble()))
+      < 0.0001);
+
+  // check that predict works fine
+  std::string jpredictstr = "{\"service\":\"detectserv\",\"parameters\":{"
+                            "\"input\":{\"height\":-1,"
+                            "\"width\":-1},\"output\":{\"bbox\":true, "
+                            "\"confidence_threshold\":0.8}},\"data\":[\""
+                            + detect_train_repo_fasterrcnn
+                            + "/imgs/000550-L.jpg\"]}";
+  joutstr = japi.jrender(japi.service_predict(jpredictstr));
+  jd = JDoc();
+  std::cout << "joutstr=" << joutstr << std::endl;
+  jd.Parse(joutstr.c_str());
+  ASSERT_TRUE(!jd.HasParseError());
+  ASSERT_EQ(200, jd["status"]["code"]);
+
+  std::unordered_set<std::string> lfiles;
+  fileops::list_directory(detect_train_repo_yolox, true, false, false, lfiles);
+  for (std::string ff : lfiles)
+    {
+      if (ff.find("checkpoint") != std::string::npos
+          || ff.find("solver") != std::string::npos)
+        remove(ff.c_str());
+    }
+  ASSERT_TRUE(!fileops::file_exists(detect_train_repo_yolox
+                                    + "checkpoint-" + iterations_detection
+                                    + ".ptw"));
+  ASSERT_TRUE(!fileops::file_exists(detect_train_repo_yolox + "checkpoint-"
+                                    + iterations_detection + ".pt"));
+
+  fileops::clear_directory(detect_train_repo_yolox + "train.lmdb");
+  fileops::clear_directory(detect_train_repo_yolox + "test_0.lmdb");
+  fileops::remove_dir(detect_train_repo_yolox + "train.lmdb");
+  fileops::remove_dir(detect_train_repo_yolox + "test_0.lmdb");
+}
+
 TEST(torchapi, service_train_images_native)
 {
  setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);