From 49432fc0744f4d2d5f02f66886f9b4087027cc76 Mon Sep 17 00:00:00 2001 From: Jih-Wei Liang Date: Wed, 18 May 2022 01:27:56 +0200 Subject: [PATCH] Add method SetAcceptEncoding for customized Accept-Encoding header (#683) --- cpr/CMakeLists.txt | 1 + cpr/accept_encoding.cpp | 41 +++++++++++++++++++++++++ cpr/session.cpp | 24 +++++++++++++-- include/CMakeLists.txt | 1 + include/cpr/accept_encoding.h | 36 ++++++++++++++++++++++ include/cpr/session.h | 5 +++ test/abstractServer.cpp | 2 +- test/httpServer.cpp | 18 +++++++++++ test/httpServer.hpp | 1 + test/session_tests.cpp | 58 +++++++++++++++++++++++++++++++++++ 10 files changed, 184 insertions(+), 3 deletions(-) create mode 100644 cpr/accept_encoding.cpp create mode 100644 include/cpr/accept_encoding.h diff --git a/cpr/CMakeLists.txt b/cpr/CMakeLists.txt index ee0ac6c20..08ab426c3 100644 --- a/cpr/CMakeLists.txt +++ b/cpr/CMakeLists.txt @@ -1,6 +1,7 @@ cmake_minimum_required(VERSION 3.15) add_library(cpr + accept_encoding.cpp async.cpp auth.cpp bearer.cpp diff --git a/cpr/accept_encoding.cpp b/cpr/accept_encoding.cpp new file mode 100644 index 000000000..4cf99abbe --- /dev/null +++ b/cpr/accept_encoding.cpp @@ -0,0 +1,41 @@ +#include "cpr/accept_encoding.h" + +namespace cpr { + +void mapMethodsToString(const std::initializer_list& from, std::vector& to) { + const auto* first = from.begin(); + const auto* last = from.end(); + auto output = std::back_inserter(to); + while (first != last) { + *output++ = AcceptEncodingMethodsStringMap.at(*first); + first = std::next(first); + } +} + +std::string concatenateMethodsWithComma(const std::vector& methods) { + auto first = std::next(methods.begin()); + auto last = methods.end(); + std::string init = methods[0]; + + for (; first != last; ++first) { + init = std::move(init) + ", " + *first; + } + return init; +} + +AcceptEncoding::AcceptEncoding(const std::initializer_list& methods) { + methods_.clear(); + mapMethodsToString(methods, methods_); +} + +AcceptEncoding::AcceptEncoding(const std::initializer_list& string_methods) : methods_{string_methods} {} + +bool AcceptEncoding::empty() const noexcept { + return methods_.empty(); +} + +const std::string AcceptEncoding::getString() const { + return concatenateMethodsWithComma(methods_); +} + +} // namespace cpr diff --git a/cpr/session.cpp b/cpr/session.cpp index 607985c48..183c93c00 100644 --- a/cpr/session.cpp +++ b/cpr/session.cpp @@ -72,6 +72,8 @@ class Session::Impl { void SetRange(const Range& range); void SetMultiRange(const MultiRange& multi_range); void SetReserveSize(const ReserveSize& reserve_size); + void SetAcceptEncoding(AcceptEncoding&& accept_encoding); + void SetAcceptEncoding(const AcceptEncoding& accept_encoding); cpr_off_t GetDownloadFileLength(); void ResponseStringReserve(size_t size); @@ -111,6 +113,7 @@ class Session::Impl { Proxies proxies_; ProxyAuthentication proxyAuth_; Header header_; + AcceptEncoding acceptEncoding_; /** * Will be set by the read callback. * Ensures that the "Transfer-Encoding" is set to "chunked", if not overriden in header_. @@ -588,6 +591,14 @@ void Session::Impl::SetReserveSize(const ReserveSize& reserve_size) { ResponseStringReserve(reserve_size.size); } +void Session::Impl::SetAcceptEncoding(const AcceptEncoding& accept_encoding) { + acceptEncoding_ = accept_encoding; +} + +void Session::Impl::SetAcceptEncoding(AcceptEncoding&& accept_encoding) { + acceptEncoding_ = std::move(accept_encoding); +} + void Session::Impl::PrepareDelete() { curl_easy_setopt(curl_->handle, CURLOPT_HTTPGET, 0L); curl_easy_setopt(curl_->handle, CURLOPT_NOBODY, 0L); @@ -818,8 +829,13 @@ void Session::Impl::prepareCommon() { #if LIBCURL_VERSION_MAJOR >= 7 #if LIBCURL_VERSION_MINOR >= 21 - /* enable all supported built-in compressions */ - curl_easy_setopt(curl_->handle, CURLOPT_ACCEPT_ENCODING, ""); + if (acceptEncoding_.empty()) { + /* enable all supported built-in compressions */ + curl_easy_setopt(curl_->handle, CURLOPT_ACCEPT_ENCODING, ""); + } + else { + curl_easy_setopt(curl_->handle, CURLOPT_ACCEPT_ENCODING, acceptEncoding_.getString().c_str()); + } #endif #endif @@ -922,6 +938,8 @@ void Session::SetHttpVersion(const HttpVersion& version) { pimpl_->SetHttpVersio void Session::SetRange(const Range& range) { pimpl_->SetRange(range); } void Session::SetMultiRange(const MultiRange& multi_range) { pimpl_->SetMultiRange(multi_range); } void Session::SetReserveSize(const ReserveSize& reserve_size) { pimpl_->SetReserveSize(reserve_size); } +void Session::SetAcceptEncoding(const AcceptEncoding& accept_encoding) { pimpl_->SetAcceptEncoding(accept_encoding); } +void Session::SetAcceptEncoding(AcceptEncoding&& accept_encoding) { pimpl_->SetAcceptEncoding(std::move(accept_encoding)); } void Session::SetOption(const ReadCallback& read) { pimpl_->SetReadCallback(read); } void Session::SetOption(const HeaderCallback& header) { pimpl_->SetHeaderCallback(header); } void Session::SetOption(const WriteCallback& write) { pimpl_->SetWriteCallback(write); } @@ -965,6 +983,8 @@ void Session::SetOption(const HttpVersion& version) { pimpl_->SetHttpVersion(ver void Session::SetOption(const Range& range) { pimpl_->SetRange(range); } void Session::SetOption(const MultiRange& multi_range) { pimpl_->SetMultiRange(multi_range); } void Session::SetOption(const ReserveSize& reserve_size) { pimpl_->SetReserveSize(reserve_size.size); } +void Session::SetOption(const AcceptEncoding& accept_encoding) { pimpl_->SetAcceptEncoding(accept_encoding); } +void Session::SetOption(AcceptEncoding&& accept_encoding) { pimpl_->SetAcceptEncoding(std::move(accept_encoding)); } cpr_off_t Session::GetDownloadFileLength() { return pimpl_->GetDownloadFileLength(); } void Session::ResponseStringReserve(size_t size) { pimpl_->ResponseStringReserve(size); } diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index 7d8656f6a..a54048f51 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt @@ -7,6 +7,7 @@ target_include_directories(cpr PUBLIC target_sources(cpr PRIVATE # Header files (useful in IDEs) + cpr/accept_encoding.h cpr/api.h cpr/async.h cpr/auth.h diff --git a/include/cpr/accept_encoding.h b/include/cpr/accept_encoding.h new file mode 100644 index 000000000..e09ad06bd --- /dev/null +++ b/include/cpr/accept_encoding.h @@ -0,0 +1,36 @@ +#ifndef CPR_ACCEPT_ENCODING_H +#define CPR_ACCEPT_ENCODING_H + +#include +#include +#include +#include +#include + +namespace cpr { + +enum class AcceptEncodingMethods { + identity, + deflate, + zlib, + gzip, +}; + +static const std::map AcceptEncodingMethodsStringMap{{AcceptEncodingMethods::identity, "identity"}, {AcceptEncodingMethods::deflate, "deflate"}, {AcceptEncodingMethods::zlib, "zlib"}, {AcceptEncodingMethods::gzip, "gzip"}}; + +class AcceptEncoding { + public: + AcceptEncoding() = default; + AcceptEncoding(const std::initializer_list& methods); + AcceptEncoding(const std::initializer_list& methods); + + bool empty() const noexcept; + const std::string getString() const; + + private: + std::vector methods_; +}; + +} // namespace cpr + +#endif diff --git a/include/cpr/session.h b/include/cpr/session.h index e8917def4..f6e216c0b 100644 --- a/include/cpr/session.h +++ b/include/cpr/session.h @@ -13,6 +13,7 @@ #include "cpr/cookies.h" #include "cpr/cprtypes.h" #include "cpr/curlholder.h" +#include "cpr/accept_encoding.h" #include "cpr/http_version.h" #include "cpr/interface.h" #include "cpr/limit_rate.h" @@ -87,6 +88,8 @@ class Session { void SetRange(const Range& range); void SetMultiRange(const MultiRange& multi_range); void SetReserveSize(const ReserveSize& reserve_size); + void SetAcceptEncoding(const AcceptEncoding& accept_encoding); + void SetAcceptEncoding(AcceptEncoding&& accept_encoding); // Used in templated functions void SetOption(const Url& url); @@ -132,6 +135,8 @@ class Session { void SetOption(const Range& range); void SetOption(const MultiRange& multi_range); void SetOption(const ReserveSize& reserve_size); + void SetOption(const AcceptEncoding& accept_encoding); + void SetOption(AcceptEncoding&& accept_encoding); cpr_off_t GetDownloadFileLength(); /** diff --git a/test/abstractServer.cpp b/test/abstractServer.cpp index 0ef50bb5d..2d63d5b0b 100644 --- a/test/abstractServer.cpp +++ b/test/abstractServer.cpp @@ -105,7 +105,7 @@ std::string AbstractServer::Base64Decode(const std::string& in) { break; } // NOLINTNEXTLINE (cppcoreguidelines-avoid-magic-numbers) - val = (val << 6) + T[c]; + val = (static_cast(val) << static_cast(6)) + T[c]; // NOLINTNEXTLINE (cppcoreguidelines-avoid-magic-numbers) valb += 6; if (valb >= 0) { diff --git a/test/httpServer.cpp b/test/httpServer.cpp index 9810916f5..327d47934 100644 --- a/test/httpServer.cpp +++ b/test/httpServer.cpp @@ -756,6 +756,22 @@ void HttpServer::OnRequestDownloadGzip(mg_connection* conn, http_message* msg) { } } +void HttpServer::OnRequestCheckAcceptEncoding(mg_connection* conn, http_message* msg) { + std::string response; + for (size_t i = 0; i < sizeof(msg->header_names) / sizeof(mg_str); i++) { + if (!msg->header_names[i].p) { + continue; + } + std::string name = std::string(msg->header_names[i].p, msg->header_names[i].len); + if (std::string{"Accept-Encoding"} == name) { + response = std::string(msg->header_values[i].p, msg->header_values[i].len); + } + } + std::string headers = "Content-Type: text/html"; + mg_send_head(conn, 200, response.length(), headers.c_str()); + mg_send(conn, response.c_str(), response.length()); +} + void HttpServer::OnRequest(mg_connection* conn, http_message* msg) { std::string uri = std::string(msg->uri.p, msg->uri.len); if (uri == "/") { @@ -822,6 +838,8 @@ void HttpServer::OnRequest(mg_connection* conn, http_message* msg) { OnRequestDownloadGzip(conn, msg); } else if (uri == "/local_port.html") { OnRequestLocalPort(conn, msg); + } else if (uri == "/check_accept_encoding.html") { + OnRequestCheckAcceptEncoding(conn, msg); } else { OnRequestNotFound(conn, msg); } diff --git a/test/httpServer.hpp b/test/httpServer.hpp index c24e65bc3..449174446 100644 --- a/test/httpServer.hpp +++ b/test/httpServer.hpp @@ -52,6 +52,7 @@ class HttpServer : public AbstractServer { static void OnRequestPatchNotAllowed(mg_connection* conn, http_message* msg); static void OnRequestDownloadGzip(mg_connection* conn, http_message* msg); static void OnRequestLocalPort(mg_connection* conn, http_message* msg); + static void OnRequestCheckAcceptEncoding(mg_connection* conn, http_message* msg); protected: mg_connection* initServer(mg_mgr* mgr, MG_CB(mg_event_handler_t event_handler, void* user_data)) override; diff --git a/test/session_tests.cpp b/test/session_tests.cpp index 80bcfddfd..675575d8d 100644 --- a/test/session_tests.cpp +++ b/test/session_tests.cpp @@ -1004,6 +1004,64 @@ TEST(BasicTests, ReserveResponseString) { EXPECT_EQ(ErrorCode::OK, response.error.code); } +TEST(BasicTests, AcceptEncodingTestWithMethodsStringMap) { + Url url{server->GetBaseUrl() + "/check_accept_encoding.html"}; + Session session; + session.SetUrl(url); + session.SetAcceptEncoding({{AcceptEncodingMethods::deflate, AcceptEncodingMethods::gzip, AcceptEncodingMethods::zlib}}); + Response response = session.Get(); + std::string expected_text{"deflate, gzip, zlib"}; + EXPECT_EQ(expected_text, response.text); + EXPECT_EQ(url, response.url); + EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]); + EXPECT_EQ(200, response.status_code); + EXPECT_EQ(ErrorCode::OK, response.error.code); +} + +TEST(BasicTests, AcceptEncodingTestWithMethodsStringMapLValue) { + Url url{server->GetBaseUrl() + "/check_accept_encoding.html"}; + Session session; + session.SetUrl(url); + AcceptEncoding accept_encoding{{AcceptEncodingMethods::deflate, AcceptEncodingMethods::gzip, AcceptEncodingMethods::zlib}}; + session.SetAcceptEncoding(accept_encoding); + Response response = session.Get(); + std::string expected_text{"deflate, gzip, zlib"}; + EXPECT_EQ(expected_text, response.text); + EXPECT_EQ(url, response.url); + EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]); + EXPECT_EQ(200, response.status_code); + EXPECT_EQ(ErrorCode::OK, response.error.code); +} + +TEST(BasicTests, AcceptEncodingTestWithCostomizedString) { + Url url{server->GetBaseUrl() + "/check_accept_encoding.html"}; + Session session; + session.SetUrl(url); + session.SetAcceptEncoding({{"deflate", "gzip", "zlib"}}); + Response response = session.Get(); + std::string expected_text{"deflate, gzip, zlib"}; + EXPECT_EQ(expected_text, response.text); + EXPECT_EQ(url, response.url); + EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]); + EXPECT_EQ(200, response.status_code); + EXPECT_EQ(ErrorCode::OK, response.error.code); +} + +TEST(BasicTests, AcceptEncodingTestWithCostomizedStringLValue) { + Url url{server->GetBaseUrl() + "/check_accept_encoding.html"}; + Session session; + session.SetUrl(url); + AcceptEncoding accept_encoding{{"deflate", "gzip", "zlib"}}; + session.SetAcceptEncoding(accept_encoding); + Response response = session.Get(); + std::string expected_text{"deflate, gzip, zlib"}; + EXPECT_EQ(expected_text, response.text); + EXPECT_EQ(url, response.url); + EXPECT_EQ(std::string{"text/html"}, response.header["content-type"]); + EXPECT_EQ(200, response.status_code); + EXPECT_EQ(ErrorCode::OK, response.error.code); +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); ::testing::AddGlobalTestEnvironment(server);