From e2df4f919e9eaac3db0a864c5f9f1b3644d356fa Mon Sep 17 00:00:00 2001 From: Jarle Aase Date: Mon, 15 Apr 2024 15:09:35 +0300 Subject: [PATCH] Added support for transactions. Moved the actual 'exec' method to Handle, so that it can be used from a handle as well. Closes #14 --- include/mysqlpool/mysqlpool.h | 430 +++++++++++++++++++++++++--------- src/mysqlpool.cpp | 27 ++- tests/integration_tests.cpp | 55 ++++- 3 files changed, 391 insertions(+), 121 deletions(-) diff --git a/include/mysqlpool/mysqlpool.h b/include/mysqlpool/mysqlpool.h index 01d1e6b..e4b0dd8 100644 --- a/include/mysqlpool/mysqlpool.h +++ b/include/mysqlpool/mysqlpool.h @@ -143,6 +143,9 @@ using results = boost::mysql::results; constexpr auto tuple_awaitable = boost::asio::as_tuple(boost::asio::use_awaitable); struct Options { + + Options(bool reconnect = true) : reconnect_and_retry_query{reconnect} {} + bool reconnect_and_retry_query{true}; std::string time_zone; bool throw_on_empty_connection{false}; @@ -151,11 +154,22 @@ struct Options { bool close_prepared_statement{false}; }; +template +T getFirstArgument(T first, ArgsT...) { + return first; +} + class Mysqlpool { + enum class ErrorMode { + IGNORE, + RETRY, + ALWAYS_FAIL + }; + + public: using connection_t = boost::mysql::tcp_ssl_connection; - template static std::string logArgs(const T... args) { using namespace detail; @@ -209,7 +223,7 @@ class Mysqlpool { cache_.erase(makeHash(key)); } - sha256_hash_t makeHash(const std::span key) { + static sha256_hash_t makeHash(const std::span key) { return sha256(std::span(reinterpret_cast(key.data()), key.size())); } @@ -217,8 +231,7 @@ class Mysqlpool { struct ArrayHasher { std::size_t operator()(const sha256_hash_t& key) const { /* I might my wrong, but it makes no sense to me to re-hash a cryptographic hash */ - size_t val = *reinterpret_cast(key.data()); - return val; + return *reinterpret_cast(key.data()); } }; std::unordered_map cache_; @@ -233,9 +246,13 @@ class Mysqlpool { }; Connection(Mysqlpool& parent); - + Connection(const Connection&) = delete; + Connection(Connection&&) = delete; ~Connection(); + Connection& operator = (const Connection&) = delete; + Connection& operator = (Connection&&) = delete; + State state() const noexcept { return state_; } @@ -248,7 +265,7 @@ class Mysqlpool { void close() { setState(State::CLOSING); - connection_.async_close([this](boost::system::error_code ec) { + connection_.async_close([this](boost::system::error_code /*ec*/) { parent_.closed(*this); }); } @@ -291,7 +308,7 @@ class Mysqlpool { } // NB: not synchronized. Assumes safe access when it's not being changed. - const std::string timeZone() const { + std::string timeZone() const { return time_zone_name_; } @@ -301,9 +318,8 @@ class Mysqlpool { } // Cache for prepared statements (per connection) - boost::asio::awaitable> getStmt(boost::mysql::diagnostics& diag, - std::string_view query) { - + boost::asio::awaitable> + getStmt(boost::mysql::diagnostics& diag, std::string_view query) { auto& cached_stmt = stmt_cache_[query]; if (!cached_stmt.valid()) { logQuery("prepare-stmt", query); @@ -334,6 +350,94 @@ class Mysqlpool { class Handle { public: + /*! Transaction handle + * + * A transaction handle is a RAII object you will + * use to commit or rollback the transaction + * before the object goes out of scope! + * + * To remind you, there is an assert in the destructor that + * will fire if the transaction is not committed or rolled back + * in non-release builds. In release build, `std::terminate()` will be called. + * If a query fails with an exception, any open transaction is automatically + * rolled back by closing the database connection. In that case the + * transaction object shuld not be used, or just call rollback(). + * + * Once a transaction is committed or rolled back, the transaction + * object is empty and can not be used again, unless you re-assign + * a new transaction to it. + */ + class Transaction { + public: + Transaction(Handle& h, bool readOnly) + : handle_{&h}, read_only_{readOnly} + { + } + + Transaction(const Transaction&) = delete; + Transaction(Transaction&& t) noexcept + : handle_{t.handle_}, read_only_{t.read_only_} { + assert(handle_); + t.handle_ = {}; + }; + + /*! Commit the current transaction */ + boost::asio::awaitable commit() { + assert(handle_); + if (handle_) { + if (handle_->failed()) { + if (!read_only_) { + throw std::runtime_error{"The transaction cannot be committed as a query has already failed!"}; + } + } else { + co_await handle_->commit(); + } + } + handle_ = {}; + co_return; + } + + /*! Roll back the current transaction */ + boost::asio::awaitable rollback() { + assert(handle_); + if (handle_ && !handle_->failed()) { + co_await handle_->rollback(); + } + handle_ = {}; + co_return; + } + + ~Transaction() { + if (handle_ && handle_->failed()) { + MYSQLPOOL_LOG_DEBUG_("Handle failed. Transaction not committed or rolled back!)"); + return; + } + assert(!handle_); + if (handle_) { + std::cerr << "Transaction not committed or rolled back!" << std::endl; + std::terminate(); + } + } + + bool empty() const noexcept { + return handle_ == nullptr; + } + + Transaction& operator = (const Transaction&) = delete; + Transaction& operator = (Transaction&& t) noexcept { + assert(t.handle_); + handle_ = t.handle_; + read_only_ = t.read_only_; + t.handle_ = {}; + return *this; + } + + private: + Handle *handle_ = {}; + bool read_only_ = false; + }; + + Handle() = default; Handle(const Handle &) = delete; @@ -370,7 +474,7 @@ class Mysqlpool { } // Return the mysql connection - auto& connection() { + connection_t& connection() { assert(connection_); return connection_->connection_; } @@ -382,6 +486,8 @@ class Mysqlpool { void reset() { parent_ = {}; connection_ = {}; + has_transaction_ = false; + failed_ = false; } bool isClosed() const noexcept { @@ -396,127 +502,238 @@ class Mysqlpool { return connection_; } + bool hasConnection() const noexcept { + return connection_ != nullptr; + } + const auto& connectionWrapper() const noexcept { return connection_; } - Mysqlpool *parent_{}; - Connection *connection_{}; - boost::uuids::uuid uuid_; - - boost::asio::awaitable reconnect(); - }; + /*! Starts a transaction + * + * @param readOnly If true, the transaction is read-only + * @param reconnect If true, the connection will be reconnected + * if a non-recoverable error occurs during the TRANSACTION query. + * + * @return A transaction object that you will use to commit or rollback the transaction. + */ + [[nodiscard]] boost::asio::awaitable transaction(bool readOnly = false, bool reconnect = true) { + assert(connection_); + assert(!has_transaction_); + Options opts; + opts.reconnect_and_retry_query = reconnect; + co_await exec(readOnly ? "START TRANSACTION READ ONLY" : "START TRANSACTION", {}); + has_transaction_ = true; + co_return Transaction{*this, readOnly}; + } - [[nodiscard]] boost::asio::awaitable getConnection(const Options& opts = {}); + private: + friend class Transaction; + boost::asio::awaitable commit() { + assert(connection_); + assert(has_transaction_); + if (connection_ && has_transaction_) { + co_await exec("COMMIT", {false}); + } + has_transaction_ = false; + co_return; + } - template - boost::asio::awaitable exec(std::string_view query, const T& tuple) { + boost::asio::awaitable rollback() { + assert(connection_); + assert(has_transaction_); + if (connection_ && has_transaction_) { + co_await exec("ROLLBACK", {false}); + } + has_transaction_ = false; + co_return; + } - results res; - co_await std::apply([&](auto... args) -> boost::asio::awaitable { - res = co_await exec(query, args...); - }, tuple); + public: + template + [[nodiscard]] boost::asio::awaitable + exec(std::string_view query, const Options& opts, argsT ...args) { + results res; + boost::mysql::diagnostics diag; - co_return res; - } + try { + again: + // TODO: Revert the session time zone back to default if opts.locale_name is empty? + if (!opts.time_zone.empty() + && !connectionWrapper()->isSameTimeZone(opts.time_zone)) { - template - boost::asio::awaitable exec(std::string_view query, const Options& opts, const T& tuple) { + static const std::string_view ts_query = "SET time_zone=?"; + auto [sec, stmt] = co_await connectionWrapper()->getStmt(diag, ts_query); + if (!handleError(sec, diag, errorMode(opts))) { + co_await reconnect(); + goto again; + } + assert(stmt != nullptr); + logQuery("locale", ts_query, opts.time_zone); + auto [ec] = co_await connection().async_execute(stmt->bind(opts.time_zone), res, diag, tuple_awaitable); + if (!handleError(ec, diag, errorMode(opts))) { + co_await reconnect(); + goto again; + } + connectionWrapper()->setTimeZone(opts.time_zone); + } - results res; - co_await std::apply([&](auto... args) -> boost::asio::awaitable { - res = co_await exec(query, opts, args...); - }, tuple); + if constexpr (sizeof...(argsT) == 0) { + logQuery("static", query); + auto [ec] = co_await connection().async_execute(query, res, diag, tuple_awaitable); + if (!handleError(ec, diag, errorMode(opts))) { + co_await reconnect(); + goto again; + } + } else { + auto [sec, stmt] = co_await connectionWrapper()->getStmt(diag, query); + if (!handleError(sec, diag, errorMode(opts))) { + co_await reconnect(); + goto again; + } + assert(stmt != nullptr); + assert(stmt->valid()); + + boost::system::error_code ec; // error status for query execution + if constexpr (sizeof...(args) == 1 && FieldViewContainer) { + // Handle dynamic arguments as a range of field_view + logQuery("stmt-dynarg", query, args...); + + auto arg = getFirstArgument(args...); + auto [err] = co_await connection().async_execute(stmt->bind(arg.begin(), arg.end()), res, diag, tuple_awaitable); + ec = err; + } else { + logQuery("stmt", query, args...); + auto [err] = co_await connection().async_execute(stmt->bind(args...), res, diag, tuple_awaitable); + ec = err; + } - co_return res; - } + // Close the statement before we evaluate the query. The error handling for the + // query may throw an exception, and we need to close the statement before that. + if (opts.close_prepared_statement) { + // Close the statement (if any error occurs, we will just log it and continue + logQuery("close-stmt", query); + const auto [csec] = co_await connection().async_close_statement(*stmt, diag, tuple_awaitable); + if (sec) { + handleError(sec, diag, ErrorMode::IGNORE); + } + connectionWrapper()->stmtCache().erase(query); + } - template - T getFirstArgument(T first, Args... args) { - return first; - } + // Handle the query error if any + if (!handleError(ec, diag, errorMode(opts))) { + co_await reconnect(); + goto again; + } + } - template - boost::asio::awaitable exec(std::string_view query, const Options& opts, argsT ...args) { - auto conn = co_await getConnection(opts); - results res; - boost::mysql::diagnostics diag; - - again: - // TODO: Revert the session time zone back to default if opts.locale_name is empty? - if (!opts.time_zone.empty() - && !conn.connectionWrapper()->isSameTimeZone(opts.time_zone)) { - - static const std::string_view ts_query = "SET time_zone=?"; - auto [sec, stmt] = co_await conn.connectionWrapper()->getStmt(diag, ts_query); - if (!handleError(sec, diag)) { - co_await conn.reconnect(); - goto again; - } - assert(stmt != nullptr); - logQuery("locale", ts_query, opts.time_zone); - auto [ec] = co_await conn.connection().async_execute(stmt->bind(opts.time_zone), res, diag, tuple_awaitable); - if (!handleError(ec, diag)) { - co_await conn.reconnect(); - goto again; + co_return std::move(res); + } catch (const std::runtime_error& ex) { + failed_ = true; + throw; } - conn.connectionWrapper()->setTimeZone(opts.time_zone); } - if constexpr (sizeof...(argsT) == 0) { - logQuery("static", query); - auto [ec] = co_await conn.connection().async_execute(query, res, diag, tuple_awaitable); - if (!handleError(ec, diag)) { - co_await conn.reconnect(); - goto again; - } - } else { - auto [sec, stmt] = co_await conn.connectionWrapper()->getStmt(diag, query); - if (!handleError(sec, diag)) { - co_await conn.reconnect(); - goto again; - } - assert(stmt != nullptr); - assert(stmt->valid()); + template + [[nodiscard]] boost::asio::awaitable + exec(std::string_view query, argsT ...args) { + results res; + co_return co_await exec(query, Options{}, args...); + } + + template + [[nodiscard]] boost::asio::awaitable + exec(std::string_view query, const T& tuple) { + results res; + co_await std::apply([&](auto... args) -> boost::asio::awaitable { + res = co_await exec(query, Options{}, args...); + }, tuple); + + co_return res; + } + + template + [[nodiscard]] boost::asio::awaitable + exec(std::string_view query, const Options& opts, const T& tuple) { - boost::system::error_code ec; // error status for query execution - if constexpr (sizeof...(args) == 1 && FieldViewContainer) { - // Handle dynamic arguments as a range of field_view - logQuery("stmt-dynarg", query, args...); + results res; + co_await std::apply([&](auto... args) -> boost::asio::awaitable { + res = co_await exec(query, opts, args...); + }, tuple); - auto arg = getFirstArgument(args...); - auto [err] = co_await conn.connection().async_execute(stmt->bind(arg.begin(), arg.end()), res, diag, tuple_awaitable); - ec = err; + co_return res; + } + + boost::asio::awaitable reconnect(); + + void release() { + assert(connection_); + assert(!connection_->isAvailable()); + if (failed_) { + connection_->close(); } else { - logQuery("stmt", query, args...); - auto [err] = co_await conn.connection().async_execute(stmt->bind(args...), res, diag, tuple_awaitable); - ec = err; + connection_->touch(); } + connection_->release(); + } - // Close the statement before we evaluate the query. The error handling for the - // query may throw an exception, and we need to close the statement before that. - if (opts.close_prepared_statement) { - // Close the statement (if any error occurs, we will just log it and continue - logQuery("close-stmt", query); - const auto [csec] = co_await conn.connection().async_close_statement(*stmt, diag, tuple_awaitable); - if (sec) { - handleError(sec, diag, false /* just report any error */); - } - conn.connectionWrapper()->stmtCache().erase(query); - } + bool failed() const noexcept { + return failed_; + } - // Handle the query error if any - if (!handleError(ec, diag)) { - co_await conn.reconnect(); - goto again; + private: + ErrorMode errorMode(const Options& opts) const noexcept { + if (opts.reconnect_and_retry_query) { + assert(!has_transaction_); + return ErrorMode::RETRY; } + return ErrorMode::ALWAYS_FAIL; } - co_return std::move(res); + Mysqlpool *parent_{}; + Connection *connection_{}; + boost::uuids::uuid uuid_; + bool has_transaction_ = false; + bool failed_ = false; + }; + + [[nodiscard]] boost::asio::awaitable getConnection(const Options& opts = {}); + + // template + // [[nodiscard]] boost::asio::awaitable + // exec(std::string_view query, const T& tuple) { + + // results res; + // co_await std::apply([&](auto... args) -> boost::asio::awaitable { + // res = co_await exec(query, args...); + // }, tuple); + + // co_return res; + // } + + // template + // [[nodiscard]] boost::asio::awaitable + // exec(std::string_view query, const Options& opts, const T& tuple) { + + // results res; + // co_await std::apply([&](auto... args) -> boost::asio::awaitable { + // res = co_await exec(query, opts, args...); + // }, tuple); + + // co_return res; + // } + + template + [[nodiscard]] boost::asio::awaitable + exec(std::string_view query, const Options& opts, argsT ...args) { + auto conn = co_await getConnection(opts); + co_return co_await conn.exec(query, opts, args...); } template - boost::asio::awaitable exec(std::string_view query, argsT ...args) { + [[nodiscard]] boost::asio::awaitable + exec(std::string_view query, argsT ...args) { co_return co_await exec(query, Options{}, args...); } @@ -622,12 +839,13 @@ class Mysqlpool { } // If it returns false, connection to server is closed - bool handleError(const boost::system::error_code& ec, boost::mysql::diagnostics& diag, bool ignore = false); + static bool handleError(const boost::system::error_code& ec, + boost::mysql::diagnostics& diag, + ErrorMode mode = ErrorMode::ALWAYS_FAIL); void startTimer(); void onTimer(boost::system::error_code ec); - void release(Handle& h) noexcept; std::string dbUser() const; std::string dbPasswd() const; diff --git a/src/mysqlpool.cpp b/src/mysqlpool.cpp index 182ce4b..6a352ea 100644 --- a/src/mysqlpool.cpp +++ b/src/mysqlpool.cpp @@ -208,7 +208,7 @@ boost::asio::awaitable Mysqlpool::init() { co_return; } -bool Mysqlpool::handleError(const boost::system::error_code &ec, boost::mysql::diagnostics &diag, bool ignore) +bool Mysqlpool::handleError(const boost::system::error_code &ec, boost::mysql::diagnostics &diag, ErrorMode em) { if (ec) { MYSQLPOOL_LOG_DEBUG_("Statement failed with error: " @@ -216,7 +216,7 @@ bool Mysqlpool::handleError(const boost::system::error_code &ec, boost::mysql::d << "). Client: " << diag.client_message() << ". Server: " << diag.server_message()); - if (ignore) { + if (em == ErrorMode::IGNORE) { MYSQLPOOL_LOG_DEBUG_("Ignoring the error..."); return false; } @@ -230,9 +230,12 @@ bool Mysqlpool::handleError(const boost::system::error_code &ec, boost::mysql::d case boost::system::errc::connection_reset: case boost::system::errc::connection_aborted: case boost::asio::error::operation_aborted: - MYSQLPOOL_LOG_DEBUG_("The error is recoverable if we re-try the query it may succeed..."); - return false; // retry - + if (em == ErrorMode::RETRY) { + MYSQLPOOL_LOG_DEBUG_("The error is recoverable if we re-try the query it may succeed..."); + return false; // retry + } + MYSQLPOOL_LOG_DEBUG_("The error is recoverable but we will not re-try the query."); + ::boost::throw_exception(db_err{ec}, BOOST_CURRENT_LOCATION); default: MYSQLPOOL_LOG_DEBUG_("The error is non-recoverable"); ::boost::throw_exception(db_err{ec}, BOOST_CURRENT_LOCATION); @@ -259,7 +262,7 @@ void Mysqlpool::startTimer() }); } -void Mysqlpool::onTimer(boost::system::error_code ec) +void Mysqlpool::onTimer(boost::system::error_code /*ec*/) { if (closed_) { MYSQLPOOL_LOG_DEBUG_("Mysqlpool::onTimer() - We are closing down the connection pool."); @@ -267,7 +270,7 @@ void Mysqlpool::onTimer(boost::system::error_code ec) } MYSQLPOOL_LOG_TRACE_("onTimer()"); - std::scoped_lock lock{mutex_}; + const std::scoped_lock lock{mutex_}; const auto watermark = chrono::steady_clock::now(); for(auto& conn : connections_) { if (conn->isAvailable() && conn->expires() <= watermark) { @@ -281,13 +284,13 @@ void Mysqlpool::onTimer(boost::system::error_code ec) } void Mysqlpool::release(Handle &h) noexcept { - if (h.connection_) { + + if (h.hasConnection()) { MYSQLPOOL_LOG_TRACE_("DB Connection " << h.uuid() << " is being released from a handle."); - std::scoped_lock lock{mutex_}; - assert(!h.connection_->isAvailable()); - h.connection_->touch(); - h.connection_->release(); + const std::scoped_lock lock{mutex_}; + h.release(); } + boost::system::error_code ec; semaphore_.cancel_one(ec); } diff --git a/tests/integration_tests.cpp b/tests/integration_tests.cpp index 9ad32eb..47199a6 100644 --- a/tests/integration_tests.cpp +++ b/tests/integration_tests.cpp @@ -161,20 +161,20 @@ TEST(Functional, TimeZone) { auto test = [](Mysqlpool& db) -> boost::asio::awaitable { jgaa::mysqlpool::Options opts; - opts.locale_name = "UTC"; + opts.time_zone = "UTC"; auto res = co_await db.exec("SELECT @@session.time_zone"); EXPECT_TRUE(res.has_value() && !res.rows().empty()); if (res.has_value() && !res.rows().empty()) { const auto zone = res.rows().front().at(0).as_string(); - EXPECT_NE(zone, opts.locale_name); + EXPECT_NE(zone, opts.time_zone); } res = co_await db.exec("SELECT @@session.time_zone", opts); EXPECT_TRUE(res.has_value() && !res.rows().empty()); if (res.has_value() && !res.rows().empty()) { const auto zone = res.rows().front().at(0).as_string(); - EXPECT_EQ(zone, opts.locale_name); + EXPECT_EQ(zone, opts.time_zone); } co_return true; }; @@ -182,6 +182,55 @@ TEST(Functional, TimeZone) { run_async_test(test, false); } +TEST (Functional, TransactionRollback) { + + + auto test = [](Mysqlpool& db) -> boost::asio::awaitable { + Options opts{false}; + auto handle = co_await db.getConnection({}); + + auto trx = co_await handle.transaction(); + + auto res = co_await handle.exec(R"(INSERT INTO mysqlpool (name) VALUES (?))", opts, "Bean"); + EXPECT_EQ(res.affected_rows(), 1); + + EXPECT_FALSE(trx.empty()); + co_await trx.rollback(); + EXPECT_TRUE(trx.empty()); + + res = co_await handle.exec(R"(SELECT COUNT(*) FROM mysqlpool WHERE name = ?)", opts, "Bean"); + EXPECT_EQ(res.rows().front().at(0).as_int64(), 0); + co_return true; + }; + + run_async_test(test, false); +} + +TEST (Functional, TransactionCommit) { + + auto test = [](Mysqlpool& db) -> boost::asio::awaitable { + + Options opts{false}; + + auto handle = co_await db.getConnection({}); + + auto trx = co_await handle.transaction(); + + auto res = co_await handle.exec(R"(INSERT INTO mysqlpool (name) VALUES (?))", opts, "Bean"); + EXPECT_EQ(res.affected_rows(), 1); + + EXPECT_FALSE(trx.empty()); + co_await trx.commit(); + EXPECT_TRUE(trx.empty()); + + res = co_await handle.exec(R"(SELECT COUNT(*) FROM mysqlpool WHERE name = ?)", opts, "Bean"); + EXPECT_EQ(res.rows().front().at(0).as_int64(), 1); + co_return true; + }; + + run_async_test(test, false); +} + int main( int argc, char * argv[] ) { MYSQLPOOL_TEST_LOGGING_SETUP("trace");