Skip to content
This repository has been archived by the owner on Mar 3, 2020. It is now read-only.

Commit

Permalink
KEP-1239: Several safety fixes in session
Browse files Browse the repository at this point in the history
  • Loading branch information
isabelsavannah committed Mar 14, 2019
1 parent 9b9d1df commit 6ad52e9
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 101 deletions.
7 changes: 7 additions & 0 deletions include/boost_asio_beast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ namespace bzn::asio

virtual bzn::asio::close_handler wrap(close_handler handler) = 0;

virtual bzn::asio::task wrap(bzn::asio::task action) = 0;

virtual boost::asio::io_context::strand& get_strand() = 0;
};

Expand Down Expand Up @@ -252,6 +254,11 @@ namespace bzn::asio
return this->s.wrap(std::move(handler));
}

bzn::asio::task wrap(bzn::asio::task action) override
{
return this->s.wrap(std::move(action));
}

boost::asio::io_context::strand& get_strand() override
{
return this->s;
Expand Down
2 changes: 2 additions & 0 deletions mocks/mock_boost_asio_beast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ namespace bzn::asio {
bzn::asio::write_handler(write_handler handler));
MOCK_METHOD1(wrap,
bzn::asio::close_handler(close_handler handler));
MOCK_METHOD1(wrap,
bzn::asio::task(bzn::asio::task handler));
MOCK_METHOD0(get_strand,
boost::asio::io_context::strand&());
};
Expand Down
200 changes: 105 additions & 95 deletions node/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <node/session.hpp>
#include <node/node.hpp>
#include <sstream>
#include <boost/beast/websocket/error.hpp>


using namespace bzn;
Expand All @@ -38,6 +39,7 @@ session::session(
, shutdown_handlers(std::move(shutdown_handlers))
, idle_timer(this->io_context->make_unique_steady_timer())
, ws_idle_timeout(std::move(ws_idle_timeout))
, strand(this->io_context->make_unique_strand())
, write_buffer(nullptr, 0)
, crypto(std::move(crypto))
, monitor(std::move(monitor))
Expand All @@ -52,115 +54,125 @@ session::start_idle_timeout()

this->idle_timer->expires_from_now(this->ws_idle_timeout);
this->idle_timer->async_wait(
[self = shared_from_this()](auto /*ec*/)
this->strand->wrap([self = shared_from_this()](auto /*ec*/)
{
if (!self->activity)
{
LOG(info) << "Closing session " << std::to_string(self->session_id) << " due to inactivity";
self->close();
self->private_close();
return;
}

self->start_idle_timeout();
});
if (!self->closing)
{
self->start_idle_timeout();
}
}));
}

void
session::open(std::shared_ptr<bzn::beast::websocket_base> ws_factory)
{
this->start_idle_timeout();

std::shared_ptr<bzn::asio::tcp_socket_base> socket = this->io_context->make_unique_tcp_socket();
socket->async_connect(this->ep,
[self = shared_from_this(), socket, ws_factory](const boost::system::error_code& ec)
{
self->activity = true;
this->strand->wrap([self = shared_from_this(), ws_factory]()
{

if (ec)
std::shared_ptr<bzn::asio::tcp_socket_base> socket = self->io_context->make_unique_tcp_socket();
socket->async_connect(self->ep,
self->strand->wrap([self, socket, ws_factory](const boost::system::error_code& ec)
{
LOG(error) << "failed to connect to: " << self->ep.address().to_string() << ":" << self->ep.port() << " - " << ec.message();

return;
}

// we've completed the handshake...

std::lock_guard<std::mutex> lock(self->socket_lock);
self->activity = true;

// set tcp_nodelay option
boost::system::error_code option_ec;
socket->get_tcp_socket().set_option(boost::asio::ip::tcp::no_delay(true), option_ec);
if (option_ec)
{
LOG(error) << "failed to set socket option: " << option_ec.message();
}
if(ec)
{
LOG(error) << "failed to connect to: " << self->ep.address().to_string() << ":" << self->ep.port() << " - " << ec.message();
return;
}

self->websocket = ws_factory->make_unique_websocket_stream(socket->get_tcp_socket());
self->websocket->async_handshake(self->ep.address().to_string(), "/",
[self, ws_factory](const boost::system::error_code& ec)
// we've completed the handshake...

// set tcp_nodelay option
boost::system::error_code option_ec;
socket->get_tcp_socket().set_option(boost::asio::ip::tcp::no_delay(true), option_ec);
if (option_ec)
{
self->activity = true;
LOG(error) << "failed to set socket option: " << option_ec.message();
}

if (ec)
self->websocket = ws_factory->make_unique_websocket_stream(socket->get_tcp_socket());
self->websocket->async_handshake(self->ep.address().to_string(), "/",
self->strand->wrap([self, ws_factory](const boost::system::error_code& ec)
{
LOG(error) << "handshake failed: " << ec.message();
self->activity = true;

if (ec)
{
LOG(error) << "handshake failed: " << ec.message();

return;
}
return;
}

self->monitor->send_counter(statistic::session_opened);
self->do_read();
self->do_write();
});
});
self->start_idle_timeout();
self->do_read();
self->do_write();
}));
}));
})();
}

void
session::accept(std::shared_ptr<bzn::beast::websocket_stream_base> ws)
{
this->start_idle_timeout();

std::lock_guard<std::mutex> lock(this->socket_lock);
this->websocket = std::move(ws);
this->websocket->async_accept(
[self = shared_from_this()](boost::system::error_code ec)
{
self->activity = true;

if (ec)
{
LOG(error) << "websocket accept failed: " << ec.message();
return;
}
this->strand->wrap([self = shared_from_this(), ws]()
{

self->monitor->send_counter(statistic::session_opened);
self->do_read();
self->do_write();
}
);
self->websocket = std::move(ws);
self->websocket->async_accept(
self->strand->wrap(
[self](boost::system::error_code ec)
{
self->activity = true;

if (ec)
{
LOG(error) << "websocket accept failed: " << ec.message();
return;
}

self->monitor->send_counter(statistic::session_opened);
self->start_idle_timeout();
self->do_read();
self->do_write();
}
)
);
})();
}

void
session::add_shutdown_handler(const bzn::session_shutdown_handler handler)
{
this->shutdown_handlers.push_back(handler);
this->strand->wrap([handler, self = shared_from_this()]()
{
self->shutdown_handlers.push_back(handler);
})();
}

void
session::do_read()
{
// assume we are invoked inside the strand

auto buffer = std::make_shared<boost::beast::multi_buffer>();
std::lock_guard<std::mutex> lock(this->socket_lock);

if (this->reading || !this->is_open())
if (this->reading || !this->is_open() || this->closing)
{
return;
}

this->reading = true;

this->websocket->async_read(
*buffer, [self = shared_from_this(), buffer](boost::system::error_code ec, auto /*bytes_transferred*/)
this->websocket->async_read(*buffer,
this->strand->wrap([self = shared_from_this(), buffer](boost::system::error_code ec, auto /*bytes_transferred*/)
{
self->activity = true;

Expand All @@ -171,7 +183,10 @@ session::do_read()
{
LOG(error) << "websocket read failed: " << ec.message();
}
self->close();
if (ec != boost::beast::websocket::error::closed)
{
self->private_close();
}
return;
}

Expand All @@ -192,33 +207,29 @@ session::do_read()

self->reading = false;
self->do_read();
}
})
);
}

void
session::do_write()
{
// because of this mutex
std::lock_guard<std::mutex> lock(this->socket_lock);
// assume we are invoked inside the strand

// at most one concurrent invocation can pass this check
if(this->writing || !this->is_open() || this->write_queue.empty())
if(this->writing || !this->is_open() || this->write_queue.empty() || this->closing)
{
return;
}

// and set this flag
this->writing = true;

auto msg = this->write_queue.front();
this->write_queue.pop_front();

// so there will only be one instance of this callback
this->websocket->binary(true);
this->write_buffer = boost::asio::buffer(*msg);
this->websocket->async_write(this->write_buffer,
[self = shared_from_this(), msg](boost::system::error_code ec, auto bytes_transferred)
this->strand->wrap([self = shared_from_this(), msg](boost::system::error_code ec, auto bytes_transferred)
{
self->activity = true;

Expand All @@ -236,25 +247,17 @@ session::do_write()
LOG(error) << "websocket read failed: " << ec.message();
}

self->write_queue.push_front(msg);
if (ec != boost::beast::websocket::error::closed)
{
std::lock_guard<std::mutex> lock(self->socket_lock);
self->write_queue.push_front(msg);
self->private_close();
}

self->close();
return;
}

// and the flag will only be reset once after each sucessful write
self->writing = false;
/* multiple threads may race to perform the next do_write, but we don't care which wins. If there are no
* others then ours definitely works because we don't try until after resetting the flag. Our resetting
* the flag can't interfere with another do_write because no such do_write can happen until we reset the
* flag.
*/

self->do_write();
});
}));
}

void
Expand Down Expand Up @@ -284,21 +287,25 @@ session::send_message(std::shared_ptr<bzn::encoded_message> msg)
return;
}

this->strand->wrap([self = shared_from_this(), msg]()
{
std::lock_guard<std::mutex> lock(this->socket_lock);
this->write_queue.push_back(msg);
}

this->do_write();
self->write_queue.push_back(msg);
self->do_write();
})();
}

void
session::close()
{
// TODO: re-open socket later if we still have messages to send? (KEP-1037)
LOG(info) << "closing session";
this->strand->wrap([self = shared_from_this()](){self->close();});
}

std::lock_guard<std::mutex> lock(this->socket_lock);
void
session::private_close()
{
// assume we are invoked inside the strand

// TODO: re-open socket later if we still have messages to send? (KEP-1037)
if (this->closing)
{
return;
Expand All @@ -315,14 +322,16 @@ session::close()
if (this->websocket && this->websocket->is_open())
{
this->websocket->async_close(boost::beast::websocket::close_code::normal,
[self = shared_from_this()](auto ec)
this->strand->wrap([self = shared_from_this()](auto ec)
{
if (ec)
{
LOG(error) << "failed to close websocket: " << ec.message();
}
});
}));
}

this->idle_timer->cancel();
}

bool
Expand All @@ -338,3 +347,4 @@ session::~session()
LOG(warning) << "dropping session with " << this->write_queue.size() << " messages left in its write queue";
}
}

4 changes: 3 additions & 1 deletion node/session.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ namespace bzn

void start_idle_timeout();

void private_close();

const bzn::session_id session_id;
const boost::asio::ip::tcp::endpoint ep;

Expand All @@ -81,7 +83,7 @@ namespace bzn
std::unique_ptr<bzn::asio::steady_timer_base> idle_timer;
const std::chrono::milliseconds ws_idle_timeout;

std::mutex socket_lock;
std::unique_ptr<bzn::asio::strand_base> strand;

std::atomic<bool> writing = false;
std::atomic<bool> reading = false;
Expand Down

0 comments on commit 6ad52e9

Please sign in to comment.