Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor handshake #77

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
7 changes: 5 additions & 2 deletions .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
boost_version: ["1.75.0", "1.76.0", "1.77.0", "1.78.0", "1.79.0"]
boost_version: ["1.75.0", "1.76.0", "1.77.0", "1.78.0", "1.79.0", "1.80.0", "1.81.0"]
os: [windows-2019, windows-2022]
toolset: [v141, v142, v143, ClangCL]
build_type: [Debug, Release]
Expand All @@ -32,6 +32,8 @@ jobs:
- { boost_version: "1.77.0", toolset: v143 }
- { boost_version: "1.78.0", toolset: v141 }
- { boost_version: "1.79.0", toolset: v141 }
- { boost_version: "1.80.0", toolset: v141 }
- { boost_version: "1.81.0", toolset: v141 }
include:
- boost_version: "1.79.0"
os: windows-2022
Expand Down Expand Up @@ -62,8 +64,9 @@ jobs:
with:
fetch-depth: 0

# For Boost Versions >= 1.78, the toolset parameter has to be specified to install-boost.
- name: Add boost toolset to environment
if: contains(fromJson('["1.78.0", "1.79.0"]'), matrix.boost_version)
if: contains(fromJson('["1.78.0", "1.79.0", "1.80.0", "1.81.0"]'), matrix.boost_version)
run: echo BOOST_TOOLSET=$([[ "${{matrix.generator}}" == "MinGW Makefiles" ]] && echo "mingw" || echo "msvc") >> $GITHUB_ENV

# The platform_version passed to boost-install determines the msvc toolset version for which static libs are installed.
Expand Down
84 changes: 20 additions & 64 deletions include/boost/wintls/detail/async_handshake.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <boost/wintls/handshake_type.hpp>

#include <boost/wintls/detail/post_self.hpp>
#include <boost/wintls/detail/sspi_handshake.hpp>

#include <boost/asio/coroutine.hpp>
Expand All @@ -18,17 +19,16 @@ namespace boost {
namespace wintls {
namespace detail {

template <typename NextLayer>
template<typename NextLayer>
struct async_handshake : boost::asio::coroutine {
async_handshake(NextLayer& next_layer, detail::sspi_handshake& handshake, handshake_type type)
: next_layer_(next_layer)
, handshake_(handshake)
, entry_count_(0)
, state_(state::idle) {
: next_layer_(next_layer)
, handshake_(handshake)
, entry_count_(0) {
handshake_(type);
}

template <typename Self>
template<typename Self>
void operator()(Self& self, boost::system::error_code ec = {}, std::size_t length = 0) {
if (ec) {
self.complete(ec);
Expand All @@ -40,94 +40,50 @@ struct async_handshake : boost::asio::coroutine {
return entry_count_ > 1;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this supposed to return void ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a small lambda that returns bool.
Although i do not know why this is a lambda and not a simple bool variable. Technically, even the entry_count_ member could be a bool i suppose.
Then again, looking into the source of asio::detail::composed_op, it seems that it keeps track of its invocations itself and that we could access that via asio::asio_handler_is_continuation(&self). Is that how it should be done?

};

switch(state_) {
case state::reading:
handshake_.size_read(length);
state_ = state::idle;
break;
case state::writing:
handshake_.size_written(length);
state_ = state::idle;
break;
case state::idle:
break;
}

detail::sspi_handshake::state handshake_state;
sspi_handshake::state handshake_state;
BOOST_ASIO_CORO_REENTER(*this) {
while((handshake_state = handshake_()) != detail::sspi_handshake::state::done) {
if (handshake_state == detail::sspi_handshake::state::data_needed) {
while (true) {
handshake_state = handshake_();
if (handshake_state == sspi_handshake::state::data_needed) {
BOOST_ASIO_CORO_YIELD {
state_ = state::reading;
next_layer_.async_read_some(handshake_.in_buffer(), std::move(self));
}
handshake_.size_read(length);
continue;
}

if (handshake_state == detail::sspi_handshake::state::data_available) {
if (handshake_state == sspi_handshake::state::data_available) {
BOOST_ASIO_CORO_YIELD {
state_ = state::writing;
net::async_write(next_layer_, handshake_.out_buffer(), std::move(self));
}
handshake_.size_written(length);
continue;
}

if (handshake_state == detail::sspi_handshake::state::error) {
if (!is_continuation()) {
BOOST_ASIO_CORO_YIELD {
auto e = self.get_executor();
net::post(e, [self = std::move(self), ec, length]() mutable { self(ec, length); });
}
}
self.complete(handshake_.last_error());
return;
}

if (handshake_state == detail::sspi_handshake::state::done_with_data) {
BOOST_ASIO_CORO_YIELD {
state_ = state::writing;
net::async_write(next_layer_, handshake_.out_buffer(), std::move(self));
}
if (handshake_state == sspi_handshake::state::error) {
break;
}

if (handshake_state == detail::sspi_handshake::state::error_with_data) {
BOOST_ASIO_CORO_YIELD {
state_ = state::writing;
net::async_write(next_layer_, handshake_.out_buffer(), std::move(self));
}
if (!is_continuation()) {
BOOST_ASIO_CORO_YIELD {
auto e = self.get_executor();
net::post(e, [self = std::move(self), ec, length]() mutable { self(ec, length); });
}
}
self.complete(handshake_.last_error());
return;
if (handshake_state == sspi_handshake::state::done) {
BOOST_ASSERT(!handshake_.last_error());
handshake_.manual_auth();
break;
}
}

if (!is_continuation()) {
BOOST_ASIO_CORO_YIELD {
auto e = self.get_executor();
net::post(e, [self = std::move(self), ec, length]() mutable { self(ec, length); });
post_self(self, next_layer_, ec, length);
}
}
BOOST_ASSERT(!handshake_.last_error());
self.complete(handshake_.last_error());
}
}

private:
NextLayer& next_layer_;
detail::sspi_handshake& handshake_;
sspi_handshake& handshake_;
int entry_count_;
std::vector<char> input_;
enum class state {
idle,
reading,
writing
} state_;
};

} // namespace detail
Expand Down
4 changes: 2 additions & 2 deletions include/boost/wintls/detail/async_read.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#ifndef BOOST_WINTLS_DETAIL_ASYNC_READ_HPP
#define BOOST_WINTLS_DETAIL_ASYNC_READ_HPP

#include <boost/wintls/detail/post_self.hpp>
#include <boost/wintls/detail/sspi_decrypt.hpp>

#include <boost/asio/coroutine.hpp>
Expand Down Expand Up @@ -50,8 +51,7 @@ struct async_read : boost::asio::coroutine {
if (state == detail::sspi_decrypt::state::error) {
if (!is_continuation()) {
BOOST_ASIO_CORO_YIELD {
auto e = self.get_executor();
net::post(e, [self = std::move(self), ec, size_read]() mutable { self(ec, size_read); });
post_self(self, next_layer_, ec, size_read);
}
}
ec = decrypt_.last_error();
Expand Down
4 changes: 2 additions & 2 deletions include/boost/wintls/detail/async_shutdown.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#ifndef BOOST_WINTLS_DETAIL_ASYNC_SHUTDOWN_HPP
#define BOOST_WINTLS_DETAIL_ASYNC_SHUTDOWN_HPP

#include <boost/wintls/detail/post_self.hpp>
#include <boost/wintls/detail/sspi_shutdown.hpp>

#include <boost/asio/coroutine.hpp>
Expand Down Expand Up @@ -49,8 +50,7 @@ struct async_shutdown : boost::asio::coroutine {
} else {
if (!is_continuation()) {
BOOST_ASIO_CORO_YIELD {
auto e = self.get_executor();
net::post(e, [self = std::move(self), ec, size_written]() mutable { self(ec, size_written); });
post_self(self, next_layer_, ec, size_written);
}
}
self.complete(ec);
Expand Down
48 changes: 48 additions & 0 deletions include/boost/wintls/detail/post_self.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//
// Copyright (c) 2020 Kasper Laudrup (laudrup at stacktrace dot dk)
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//

#ifndef BOOST_WINTLS_DETAIL_POST_SELF_HPP
#define BOOST_WINTLS_DETAIL_POST_SELF_HPP

#if BOOST_VERSION >= 108000
#include <boost/asio/append.hpp>
#endif
#include <boost/asio/post.hpp>
#include <boost/core/ignore_unused.hpp>
#include <boost/version.hpp>

namespace boost {
namespace wintls {
namespace detail {

// If a composed asynchronous operation completes immediately (due to an error)
// we do not want to call self.complete() directly as this may produce an infinite recursion in some cases.
// Instead, we post the intermediate completion handler (self) once.
// To achieve consistent behavior to non-erroneous cases, we post to the executor of the I/O object.
// Note that this only got accessible through self by get_io_executor since boost 1.81.
template<typename Self, typename IoObject, typename... Args>
auto post_self(Self& self, IoObject& io_object, boost::system::error_code ec, std::size_t length) {
#if BOOST_VERSION >= 108100
boost::ignore_unused(io_object);
auto ex = self.get_io_executor();
return boost::asio::post(ex, boost::asio::append(std::move(self), ec, length));
#elif BOOST_VERSION >= 108000
return boost::asio::post(io_object.get_executor(), boost::asio::append(std::move(self), ec, length));
#else
auto ex = io_object.get_executor();
// If the completion token associated with self had an associated executor,
// allocator or cancellation slot, we loose these here.
// Therefore, above solutions are better!
return boost::asio::post(ex, [self = std::move(self), ec, length]() mutable { self(ec, length); });
#endif
}

} // namespace detail
} // namespace wintls
} // namespace boost

#endif //BOOST_WINTLS_DETAIL_POST_SELF_HPP
67 changes: 24 additions & 43 deletions include/boost/wintls/detail/sspi_handshake.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,11 @@ namespace detail {

class sspi_handshake {
public:
// TODO: enhancement: done_with_data and error_with_data can be removed if
// we move the manual validate logic out of the handshake loop.
enum class state {
data_needed, // data needs to be read from peer
data_available, // data needs to be write to peer
done_with_data, // handshake success, but there is leftover data to be written to peer
error_with_data, // handshake error, but there is leftover data to be written to peer
done, // handshake success
error // handshake error
data_needed, // data needs to be read from peer
data_available, // data needs to be write to peer
done, // handshake success
error // handshake error
};

sspi_handshake(context& context, ctxt_handle& ctxt_handle, cred_handle& cred_handle)
Expand Down Expand Up @@ -73,16 +69,11 @@ class sspi_handshake {
BOOST_UNREACHABLE_RETURN(0);
}();

auto server_cert = context_.server_cert();
if (handshake_type_ == handshake_type::server && server_cert != nullptr) {
creds.cCreds = 1;
creds.paCred = &server_cert;
}

// TODO: rename server_cert field since it is also used for client cert.
// Note: if client cert is set, sspi will auto validate server cert with it.
// Even though verify_server_certificate_ in context is set to false.
if (handshake_type_ == handshake_type::client && server_cert != nullptr) {
auto server_cert = context_.server_cert();
if (server_cert != nullptr) {
creds.cCreds = 1;
creds.paCred = &server_cert;
}
Expand Down Expand Up @@ -129,6 +120,9 @@ class sspi_handshake {
}

state operator()() {
if (last_error_ == SEC_E_OK) {
return state::done;
}
if (last_error_ != SEC_I_CONTINUE_NEEDED && last_error_ != SEC_E_INCOMPLETE_MESSAGE) {
return state::error;
}
Expand Down Expand Up @@ -209,28 +203,20 @@ class sspi_handshake {
return has_buffer_output ? state::data_available : state::data_needed;
}
case SEC_E_OK: {
// sspi handshake ok. perform manual auth here.
manual_auth();
if (handshake_type_ == handshake_type::client) {
if (last_error_ != SEC_E_OK) {
return state::error;
}
} else {
// Note: we are not checking (out_flags & ASC_RET_MUTUAL_AUTH) is true,
// but instead rely on our manual cert validation to establish trust.
// "The AcceptSecurityContext function will return ASC_RET_MUTUAL_AUTH if a
// client certificate was received from the client and schannel was
// successfully able to map the certificate to a user account in AD"
// As observed in tests, this check would wrongly reject openssl client with valid certificate.

// AcceptSecurityContext documentation:
// "If function generated an output token, the token must be sent to the client process."
// This happens when client cert is requested.
if (has_buffer_output) {
return last_error_ == SEC_E_OK ? state::done_with_data : state::error_with_data;
}
}
return state::done;
// sspi handshake ok. Manual authentication will be done after the handshake loop.

// Note: When we requested client auth as a server,
// we are not checking (out_flags & ASC_RET_MUTUAL_AUTH) is true,
// but instead rely on our manual cert validation to establish trust.
// "The AcceptSecurityContext function will return ASC_RET_MUTUAL_AUTH if a
// client certificate was received from the client and schannel was
// successfully able to map the certificate to a user account in AD"
// As observed in tests, this check would wrongly reject openssl client with valid certificate.

// AcceptSecurityContext/InitializeSecurityContext documentation for return value SEC_E_OK:
// "If function generated an output token, the token must be sent to the client/server."
// This happens when client cert is requested.
return has_buffer_output ? state::data_available : state::done;
}

case SEC_I_INCOMPLETE_CREDENTIALS:
Expand Down Expand Up @@ -274,7 +260,6 @@ class sspi_handshake {
check_revocation_ = check;
}

private:
SECURITY_STATUS manual_auth(){
if (!context_.verify_server_certificate_) {
return SEC_E_OK;
Expand All @@ -284,16 +269,12 @@ class sspi_handshake {
if (last_error_ != SEC_E_OK) {
return last_error_;
}

cert_context_ptr remote_cert{ctx_ptr};

last_error_ = static_cast<SECURITY_STATUS>(context_.verify_certificate(remote_cert.get(), server_hostname_, check_revocation_));
if (last_error_ != SEC_E_OK) {
return last_error_;
}
return last_error_;
}

private:
context& context_;
ctxt_handle& ctxt_handle_;
cred_handle& cred_handle_;
Expand Down
Loading