diff --git a/mocks/mock_session_base.hpp b/mocks/mock_session_base.hpp index 969932ab..19c027b3 100644 --- a/mocks/mock_session_base.hpp +++ b/mocks/mock_session_base.hpp @@ -33,7 +33,12 @@ namespace bzn { void()); MOCK_CONST_METHOD0(is_open, bool()); - MOCK_METHOD1(open, void(std::shared_ptr ws_factory)); - MOCK_METHOD1(accept, void(std::shared_ptr ws)); + MOCK_METHOD1(open, + void(std::shared_ptr ws_factory)); + MOCK_METHOD1(accept, + void(std::shared_ptr ws)); + MOCK_METHOD1(add_shutdown_handler, + void(bzn::session_shutdown_handler handler)); + }; } // namespace bzn diff --git a/node/node.cpp b/node/node.cpp index 0d69fb68..735da981 100644 --- a/node/node.cpp +++ b/node/node.cpp @@ -98,12 +98,12 @@ node::do_accept() , self->chaos , std::bind(&node::priv_protobuf_handler, self, std::placeholders::_1, std::placeholders::_2) , self->options->get_ws_idle_timeout() - , [](){} + , std::list{[](){}} , self->crypto); session->accept(std::move(ws)); - LOG(info) << "accepting new incomming connection with " << key; + LOG(info) << "accepting new incoming connection with " << key; // Do not attempt to identify the incoming session; one ip address could be running multiple daemons // and we can't identify them based on the outgoing ports they choose } @@ -163,7 +163,7 @@ node::find_session(const boost::asio::ip::tcp::endpoint& ep) , this->chaos , std::bind(&node::priv_protobuf_handler, shared_from_this(), std::placeholders::_1, std::placeholders::_2) , this->options->get_ws_idle_timeout() - , std::bind(&node::priv_session_shutdown_handler, shared_from_this(), key) + , std::list{std::bind(&node::priv_session_shutdown_handler, shared_from_this(), key)} , this->crypto); session->open(this->websocket); sessions.insert_or_assign(key, session); diff --git a/node/node.hpp b/node/node.hpp index 35b2d1ab..5758cdf3 100644 --- a/node/node.hpp +++ b/node/node.hpp @@ -30,7 +30,6 @@ namespace bzn { using ep_key_t = std::string; - using session_shutdown_handler = std::function; class node final : public bzn::node_base, public std::enable_shared_from_this { diff --git a/node/session.cpp b/node/session.cpp index be49fd4c..df5f478c 100644 --- a/node/session.cpp +++ b/node/session.cpp @@ -26,7 +26,7 @@ session::session( std::shared_ptr chaos, bzn::protobuf_handler proto_handler, std::chrono::milliseconds ws_idle_timeout, - bzn::session_shutdown_handler shutdown_handler, + std::list shutdown_handlers, std::shared_ptr crypto ) : session_id(session_id) @@ -34,7 +34,7 @@ session::session( , io_context(std::move(io_context)) , chaos(std::move(chaos)) , proto_handler(std::move(proto_handler)) - , shutdown_handler(std::move(shutdown_handler)) + , shutdown_handlers(std::move(shutdown_handlers)) , idle_timer(this->io_context->make_unique_steady_timer()) , ws_idle_timeout(std::move(ws_idle_timeout)) , write_buffer(nullptr, 0) @@ -127,6 +127,12 @@ session::accept(std::shared_ptr ws) ); } +void +session::add_shutdown_handler(const bzn::session_shutdown_handler handler) +{ + this->shutdown_handlers.push_back(handler); +} + void session::do_read() { @@ -281,7 +287,11 @@ session::close() this->closing = true; LOG(debug) << "closing session " << std::to_string(this->session_id); - this->io_context->post(this->shutdown_handler); + + for(const auto& handler : this->shutdown_handlers) + { + this->io_context->post(handler); + } if (this->websocket && this->websocket->is_open()) { diff --git a/node/session.hpp b/node/session.hpp index c438f9a0..158c66c6 100644 --- a/node/session.hpp +++ b/node/session.hpp @@ -40,7 +40,7 @@ namespace bzn std::shared_ptr chaos, bzn::protobuf_handler proto_handler, std::chrono::milliseconds ws_idle_timeout, - bzn::session_shutdown_handler shutdown_handler, + std::list shutdown_handlers, std::shared_ptr crypto); ~session(); @@ -57,6 +57,8 @@ namespace bzn void open(std::shared_ptr ws_factory) override; void accept(std::shared_ptr ws) override; + void add_shutdown_handler(const bzn::session_shutdown_handler handler) override; + private: void do_read(); void do_write(); @@ -73,7 +75,7 @@ namespace bzn std::list> write_queue; bzn::protobuf_handler proto_handler; - bzn::session_shutdown_handler shutdown_handler; + std::list shutdown_handlers; std::unique_ptr idle_timer; const std::chrono::milliseconds ws_idle_timeout; diff --git a/node/session_base.hpp b/node/session_base.hpp index 8bc0c6c6..06be926c 100644 --- a/node/session_base.hpp +++ b/node/session_base.hpp @@ -24,6 +24,7 @@ namespace bzn // forward declare... class session_base; + using session_shutdown_handler = std::function; using message_handler = std::function session)>; using protobuf_handler = std::function session)>; @@ -71,6 +72,12 @@ namespace bzn * Accept an incoming connection on some websocket */ virtual void accept(std::shared_ptr ws) = 0; + + /** + * Add additional shutdown handlers to the session + * @param handler + */ + virtual void add_shutdown_handler(bzn::session_shutdown_handler handler) = 0; }; } // bzn diff --git a/node/test/session_test.cpp b/node/test/session_test.cpp index e95c1f9f..68d8ae28 100644 --- a/node/test/session_test.cpp +++ b/node/test/session_test.cpp @@ -21,6 +21,8 @@ #include #include +#include + using namespace ::testing; namespace @@ -66,7 +68,7 @@ class session_test2 : public Test session_test2() { - session = std::make_shared(mock.io_context, 0, TEST_ENDPOINT, this->mock_chaos, [&](auto, auto){this->handler_called++;}, TEST_TIMEOUT, [](){}, nullptr); + session = std::make_shared(mock.io_context, 0, TEST_ENDPOINT, this->mock_chaos, [&](auto, auto){this->handler_called++;}, TEST_TIMEOUT, std::list{[](){}}, nullptr); } void yield() @@ -93,7 +95,7 @@ namespace bzn EXPECT_CALL(*mock_websocket_stream, async_read(_,_)); - auto session = std::make_shared(this->io_context, bzn::session_id(1), TEST_ENDPOINT, this->mock_chaos, [](auto, auto){}, TEST_TIMEOUT, [](){}, nullptr); + auto session = std::make_shared(this->io_context, bzn::session_id(1), TEST_ENDPOINT, this->mock_chaos, [](auto, auto){}, TEST_TIMEOUT, std::list{[](){}}, nullptr); session->accept(mock_websocket_stream); accept_handler(boost::system::error_code{}); @@ -146,7 +148,7 @@ namespace bzn bzn::smart_mock_io mock; mock.tcp_connect_works = false; - auto session = std::make_shared(mock.io_context, 0, TEST_ENDPOINT, this->mock_chaos, [](auto, auto){}, TEST_TIMEOUT, [](){}, nullptr); + auto session = std::make_shared(mock.io_context, 0, TEST_ENDPOINT, this->mock_chaos, [](auto, auto){}, TEST_TIMEOUT, std::list{[](){}}, nullptr); session->open(mock.websocket); this->yield(); @@ -156,4 +158,39 @@ namespace bzn mock.timer_callbacks.at(0)(boost::system::error_code{}); } + TEST_F(session_test2, additional_shutdown_handlers_can_be_added_to_session) + { + + bzn::smart_mock_io mock; + mock.tcp_connect_works = false; + + std::vector handler_counters { 0,0,0 }; + { + auto session = std::make_shared(mock.io_context + , 0, TEST_ENDPOINT, this->mock_chaos, [](auto, auto){}, TEST_TIMEOUT + , std::list{[&handler_counters]() { + ++handler_counters[0]; + }}, nullptr); + + session->add_shutdown_handler([&handler_counters](){++handler_counters[1];}); + session->add_shutdown_handler([&handler_counters](){++handler_counters[2];}); + + session->open(mock.websocket); + + this->yield(); + + // we are just testing that this doesn't cause a segfault + mock.timer_callbacks.at(0)(boost::system::error_code{}); + mock.timer_callbacks.at(0)(boost::system::error_code{}); + + } + this->yield(); + + // each shutdown handler must be called exactly once. + for(const auto handler_counter : handler_counters) + { + EXPECT_EQ(handler_counter, 1); + } + } + } // bzn diff --git a/pbft/pbft.cpp b/pbft/pbft.cpp index 11e2faac..6aeac37c 100644 --- a/pbft/pbft.cpp +++ b/pbft/pbft.cpp @@ -295,7 +295,7 @@ pbft::handle_request(const bzn_envelope& request_env, const std::shared_ptrsessions_waiting_on_forwarded_requests.find(hash) == this->sessions_waiting_on_forwarded_requests.end()) { - this->sessions_waiting_on_forwarded_requests[hash] = session; + this->add_session_to_pool(hash, session); } } @@ -444,7 +444,8 @@ pbft::handle_join_or_leave(const bzn_envelope& env, const pbft_membership_msg& m return; } - this->sessions_waiting_on_forwarded_requests[msg_hash] = session; + this->add_session_to_pool(msg_hash, session); + if (!this->is_primary()) { this->forward_request_to_primary(env); @@ -1899,3 +1900,20 @@ uint32_t pbft::generate_random_number(uint32_t min, uint32_t max) std::uniform_int_distribution dist(min, max); return dist(gen); } + +void pbft::add_session_to_pool(const std::string& msg_hash, std::shared_ptr session) +{ + if (session) + { + this->sessions_waiting_on_forwarded_requests[msg_hash] = session; + session->add_shutdown_handler([msg_hash, this]() + { + std::lock_guard lock(this->pbft_lock); + auto it = this->sessions_waiting_on_forwarded_requests.find(msg_hash); + if (it != this->sessions_waiting_on_forwarded_requests.end() && !it->second->is_open()) + { + this->sessions_waiting_on_forwarded_requests.erase(it); + } + }); + } +} diff --git a/pbft/pbft.hpp b/pbft/pbft.hpp index 9d01339c..2c18e931 100644 --- a/pbft/pbft.hpp +++ b/pbft/pbft.hpp @@ -50,6 +50,7 @@ namespace bzn { // fwd declare test as it's not in the same namespace... class pbft_test_database_response_is_forwarded_to_session_Test; + class pbft_test_add_session_to_pool_can_add_a_session_and_shutdown_handler_removes_session_from_pool_Test; } using request_hash_t = std::string; @@ -216,6 +217,8 @@ namespace bzn bool is_peer(const bzn::uuid_t& peer) const; bool get_sequences_and_request_hashes_from_proofs( const pbft_msg& viewchange_msg, std::set>& sequence_request_pairs) const; + void add_session_to_pool(const std::string& msg_hash, std::shared_ptr session); + // Using 1 as first value here to distinguish from default value of 0 in protobuf uint64_t view = 1; uint64_t next_issued_sequence_number = 1; @@ -296,6 +299,7 @@ namespace bzn FRIEND_TEST(pbft_newview_test, get_sequences_and_request_hashes_from_proofs); FRIEND_TEST(pbft_newview_test, test_last_sequence_in_newview_prepared_proofs); FRIEND_TEST(bzn::test::pbft_test, database_response_is_forwarded_to_session); + FRIEND_TEST(bzn::test::pbft_test, add_session_to_pool_can_add_a_session_and_shutdown_handler_removes_session_from_pool); friend class pbft_proto_test; friend class pbft_join_leave_test; diff --git a/pbft/test/pbft_test.cpp b/pbft/test/pbft_test.cpp index 307788eb..9e699b08 100644 --- a/pbft/test/pbft_test.cpp +++ b/pbft/test/pbft_test.cpp @@ -200,6 +200,34 @@ namespace bzn::test this->database_response_handler(this->request_msg, mock_session); } + TEST_F(pbft_test, add_session_to_pool_can_add_a_session_and_shutdown_handler_removes_session_from_pool) + { + this->build_pbft(); + + EXPECT_EQ(size_t(0), this->pbft->sessions_waiting_on_forwarded_requests.size()); + + bzn::session_shutdown_handler shutdown_handler{0}; + + EXPECT_CALL(*mock_session, add_shutdown_handler(_)) + .Times(Exactly(1)) + .WillRepeatedly(Invoke([&](auto handler) { + shutdown_handler = handler; + })); + + EXPECT_CALL(*mock_session, is_open()) + .WillOnce(Return(false)); + + pbft->handle_database_message(this->request_msg, this->mock_session); + + EXPECT_EQ(size_t(1), this->pbft->sessions_waiting_on_forwarded_requests.size()); + + EXPECT_TRUE(shutdown_handler != nullptr); + + shutdown_handler(); + + EXPECT_EQ(size_t(0), this->pbft->sessions_waiting_on_forwarded_requests.size()); + } + TEST_F(pbft_test, client_request_executed_results_in_message_response) {