From b510e908d8e3476030492ebb9fd059b50398fbdf Mon Sep 17 00:00:00 2001 From: Vittorio Rigamonti Date: Tue, 21 Jun 2016 12:54:38 +0200 Subject: [PATCH] HRCPP-265 TSL SChannel implementation --- CMakeLists.txt | 107 +- .../hotrod/SslConfigurationBuilder.h | 3 +- .../transport/tcp/SChannelTcpTransport.cpp | 26 + .../impl/transport/tcp/SChannelTcpTransport.h | 34 + src/hotrod/impl/transport/tcp/Socket.cpp | 2 +- .../transport/tcp/TransportObjectFactory.cpp | 9 +- src/hotrod/sys/SChannelSocket.cpp | 1126 +++++++++++++++++ src/hotrod/sys/SChannelSocket.h | 81 ++ test/SimpleSChannel.cpp | 45 + 9 files changed, 1372 insertions(+), 61 deletions(-) create mode 100644 src/hotrod/impl/transport/tcp/SChannelTcpTransport.cpp create mode 100644 src/hotrod/impl/transport/tcp/SChannelTcpTransport.h create mode 100644 src/hotrod/sys/SChannelSocket.cpp create mode 100644 src/hotrod/sys/SChannelSocket.h create mode 100644 test/SimpleSChannel.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f04c7e7..175f4d26 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -306,45 +306,11 @@ if(DEFINED HOTROD_PREBUILT_LIB_DIR) # set(PROTOBUF_PROTOC_LIBRARY ${HR_PROTOBUF_PROTOC_LIBRARY}) #endif("${HR_PROTOBUF_PROTOC_LIBRARY}" STREQUAL "HR_PROTOBUF_PROTOC_LIBRARY-NOTFOUND") - if(MSVC) - find_library(SSLEAY32_LIBRARY NAMES ssleay32 PATHS ${HOTROD_PREBUILT_LIB_DIR}) - if("${SSLEAY32_LIBRARY}" STREQUAL "SSLEAY32_LIBRARY-NOTFOUND") - message(FATAL_ERROR "Cannot find ssleay32 library in directory '${HOTROD_PREBUILT_LIB_DIR}'.") - else("${SSLEAY32_LIBRARY}" STREQUAL "SSLEAY32_LIBRARY-NOTFOUND") - message("-- Found ssleay32 static library: ${SSLEAY32_LIBRARY}") - add_library(ssleay32 STATIC IMPORTED GLOBAL) - set_target_properties(ssleay32 PROPERTIES IMPORTED_LOCATION ${SSLEAY32_LIBRARY}) - endif("${SSLEAY32_LIBRARY}" STREQUAL "SSLEAY32_LIBRARY-NOTFOUND") - - find_library(LIBEAY32_LIBRARY NAMES libeay32 PATHS ${HOTROD_PREBUILT_LIB_DIR}) - if("${LIBEAY32_LIBRARY}" STREQUAL "LIBEAY32_LIBRARY-NOTFOUND") - message(FATAL_ERROR "Cannot find libeay32 library in directory '${HOTROD_PREBUILT_LIB_DIR}'.") - else("${LIBEAY32_LIBRARY}" STREQUAL "LIBEAY32_LIBRARY-NOTFOUND") - message("-- Found libeay32 static library: ${LIBEAY32_LIBRARY}") - add_library(libeay32 STATIC IMPORTED GLOBAL) - set_target_properties(libeay32 PROPERTIES IMPORTED_LOCATION ${LIBEAY32_LIBRARY}) - endif("${LIBEAY32_LIBRARY}" STREQUAL "LIBEAY32_LIBRARY-NOTFOUND") - set(OPENSSL_LIBRARIES ${LIBEAY32_LIBRARY} ${SSLEAY32_LIBRARY}) -find_path(OPENSSL_INCLUDE_DIR - NAMES - openssl/ssl.h - HINTS - ${HOTROD_PREBUILT_LIB_DIR}/.. - PATH_SUFFIXES - include -) - if("${OPENSSL_INCLUDE_DIR}" STREQUAL "OPENSSL_INCLUDE_DIR-NOTFOUND") - message(FATAL_ERROR "Cannot find openssl/ssl.h file in '${HOTROD_PREBUILT_LIB_DIR}/..'.") - else("${OPENSSL_INCLUDE_DIR}" STREQUAL "OPENSSL_INCLUDE_DIR-NOTFOUND") - message("-- Found openssl/ssl.h file in '${OPENSSL_INCLUDE_DIR}'.") - include_directories(${OPENSSL_INCLUDE_DIR}) - endif("${OPENSSL_INCLUDE_DIR}" STREQUAL "OPENSSL_INCLUDE_DIR-NOTFOUND") - else(MSVC) - find_package(OpenSSL) - include_directories(${OPENSSL_INCLUDE_DIR}) - endif(MSVC) + if(NOT DEFINED WIN32) + find_package(OpenSSL) + include_directories(${OPENSSL_INCLUDE_DIR}) + endif(NOT DEFINED WIN32) - if(UNIX) set (PROTOBUF_PROTOC_EXECUTABLE ${HOTROD_PREBUILT_LIB_DIR}/../bin/protoc) else(UNIX) @@ -356,9 +322,10 @@ else(DEFINED HOTROD_PREBUILT_LIB_DIR) find_package(Protobuf REQUIRED) # OpenSSL support -find_package(OpenSSL) -include_directories(${OPENSSL_INCLUDE_DIR}) - +if(NOT DEFINED WIN32) + find_package(OpenSSL) + include_directories(${OPENSSL_INCLUDE_DIR}) +endif(NOT DEFINED WIN32) include_directories ("${CMAKE_CURRENT_SOURCE_DIR}/include" "${CMAKE_CURRENT_SOURCE_DIR}/src" "${CMAKE_CURRENT_BINARY_DIR}") set (INCLUDE_FILES_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../include) @@ -379,6 +346,17 @@ set (internal_test_sources src/hotrod/test/HashTest.cpp src/hotrod/test/ConnectionPoolTest.cpp ) +if(WIN32) +set (socket_impl + src/hotrod/impl/transport/tcp/SChannelTcpTransport.cpp + src/hotrod/sys/SChannelSocket.cpp + ) +else(WIN32) +set (socket_impl + src/hotrod/impl/transport/tcp/SSLTcpTransport.cpp + src/hotrod/sys/SSLSocket.cpp + ) +endif(WIN32) endif(ENABLE_INTERNAL_TESTING) set (library_sources @@ -430,14 +408,13 @@ endif(ENABLE_INTERNAL_TESTING) src/hotrod/impl/transport/tcp/ConnectionPool.cpp src/hotrod/impl/transport/tcp/InetSocketAddress.cpp src/hotrod/impl/transport/tcp/Socket.cpp - src/hotrod/impl/transport/tcp/SSLTcpTransport.cpp src/hotrod/impl/transport/tcp/TcpTransport.cpp src/hotrod/impl/transport/tcp/TcpTransportFactory.cpp src/hotrod/impl/transport/tcp/TransportObjectFactory.cpp src/hotrod/impl/transport/tcp/RoundRobinBalancingStrategy.cpp src/hotrod/sys/Runnable.cpp src/hotrod/sys/Log.cpp - src/hotrod/sys/SSLSocket.cpp + ${socket_impl} ${platform_sources} ${internal_test_sources} ${CMAKE_BINARY_DIR}/Version.cpp @@ -559,17 +536,31 @@ set_target_properties(queryTest-static PROPERTIES COMPILE_DEFINITIONS "${DLLEXPO set_target_properties (queryTest-static PROPERTIES COMPILE_FLAGS "${COMPILER_FLAGS} ${WARNING_FLAGS_NO_PEDANTIC} ${STATIC_FLAGS} ${NO_UNUSED_FLAGS}") target_link_libraries (queryTest-static hotrod-static hotrod_protobuf ${PROTOBUF_LIBRARY} ${OPENSSL_LIBRARIES} ${platform_libs}) -add_executable (simple-tls test/SimpleTLS.cpp) -target_include_directories(simple-tls PUBLIC "${CMAKE_CURRENT_BINARY_DIR}/test/query_proto" - "${INCLUDE_FILES_DIR}" - "${INCLUDE_FILES_DIR}/infinispan/hotrod" - "${CMAKE_CURRENT_BINARY_DIR}" - "${PROTOBUF_INCLUDE_DIR}") -set_property(TARGET simple-tls PROPERTY CXX_STANDARD 11) -set_property(TARGET simple-tls PROPERTY CXX_STANDARD_REQUIRED ON) -set_target_properties(simple-tls PROPERTIES COMPILE_DEFINITIONS "${DLLEXPORT_STATIC}") -set_target_properties (simple-tls PROPERTIES COMPILE_FLAGS "${COMPILER_FLAGS} ${WARNING_FLAGS_NO_PEDANTIC} ${NO_UNUSED_FLAGS}") -target_link_libraries (simple-tls hotrod hotrod_protobuf ${PROTOBUF_LIBRARY} ${platform_libs} ${OPENSSL_LIBRARIES}) +if(WIN32) + add_executable (simple-schannel test/SimpleSChannel.cpp) + target_include_directories(simple-schannel PUBLIC "${CMAKE_CURRENT_BINARY_DIR}/test/query_proto" + "${INCLUDE_FILES_DIR}" + "${INCLUDE_FILES_DIR}/infinispan/hotrod" + "${CMAKE_CURRENT_BINARY_DIR}" + "${PROTOBUF_INCLUDE_DIR}") + set_property(TARGET simple-schannel PROPERTY CXX_STANDARD 11) + set_property(TARGET simple-schannel PROPERTY CXX_STANDARD_REQUIRED ON) + set_target_properties(simple-schannel PROPERTIES COMPILE_DEFINITIONS "${DLLEXPORT_STATIC}") + set_target_properties (simple-schannel PROPERTIES COMPILE_FLAGS "${COMPILER_FLAGS} ${WARNING_FLAGS_NO_PEDANTIC} ${NO_UNUSED_FLAGS}") + target_link_libraries (simple-schannel hotrod hotrod_protobuf ${PROTOBUF_LIBRARY} ${platform_libs} ${OPENSSL_LIBRARIES}) +else(WIN32) + add_executable (simple-tls test/SimpleTLS.cpp) + target_include_directories(simple-tls PUBLIC "${CMAKE_CURRENT_BINARY_DIR}/test/query_proto" + "${INCLUDE_FILES_DIR}" + "${INCLUDE_FILES_DIR}/infinispan/hotrod" + "${CMAKE_CURRENT_BINARY_DIR}" + "${PROTOBUF_INCLUDE_DIR}") + set_property(TARGET simple-tls PROPERTY CXX_STANDARD 11) + set_property(TARGET simple-tls PROPERTY CXX_STANDARD_REQUIRED ON) + set_target_properties(simple-tls PROPERTIES COMPILE_DEFINITIONS "${DLLEXPORT_STATIC}") + set_target_properties (simple-tls PROPERTIES COMPILE_FLAGS "${COMPILER_FLAGS} ${WARNING_FLAGS_NO_PEDANTIC} ${NO_UNUSED_FLAGS}") + target_link_libraries (simple-tls hotrod hotrod_protobuf ${PROTOBUF_LIBRARY} ${platform_libs} ${OPENSSL_LIBRARIES}) +endif(WIN32) if (ENABLE_INTERNAL_TESTING) add_executable (unittest test/Unit.cpp) @@ -637,7 +628,11 @@ else (NOT ((EXISTS "${HOTROD_JBOSS_HOME}/bin/standalone.sh") AND (EXISTS "${HOTR add_test (stop_server ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/bin/server_ctl.py stop) add_test (start_ssl_server ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/bin/server_ctl.py start ${JAVA_RUNTIME} ${HOTROD_JBOSS_HOME} standalone-hotrod-ssl.xml) add_test (probe_ssl_port ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/bin/probe_port.py localhost 11222 60) +if(WIN32) + add_test (simple-schannel simple-schannel ${CMAKE_CURRENT_SOURCE_DIR}/test/data/infinispan-ca.pem) +else(WIN32) add_test (simple-tls simple-tls ${CMAKE_CURRENT_SOURCE_DIR}/test/data/infinispan-ca.pem) +endif(WIN32) add_test (stop_ssl_server ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/bin/server_ctl.py stop) endif (NOT ((EXISTS "${HOTROD_JBOSS_HOME}/bin/standalone.sh") AND (EXISTS "${HOTROD_JBOSS_HOME}/bin/standalone.bat"))) @@ -707,10 +702,6 @@ if (NOT WIN32) install (FILES ${libproto_glob} DESTINATION lib${LIB_SUFFIX}) endif (NOT WIN32) -# If on windows, include openssl as well -if(WIN32) - install (DIRECTORY "${OPENSSL_ROOT_DIR}/lib" "${OPENSSL_ROOT_DIR}/bin" "${OPENSSL_ROOT_DIR}/exp" "${OPENSSL_ROOT_DIR}/include" DESTINATION .) -endif(WIN32) if(NOT DEFINED HOTROD_PREBUILT_LIB_DIR) set_property(TARGET hotrod PROPERTY CXX_STANDARD 11) set_property(TARGET hotrod PROPERTY CXX_STANDARD_REQUIRED ON) diff --git a/include/infinispan/hotrod/SslConfigurationBuilder.h b/include/infinispan/hotrod/SslConfigurationBuilder.h index 912ec248..28ed6ab1 100644 --- a/include/infinispan/hotrod/SslConfigurationBuilder.h +++ b/include/infinispan/hotrod/SslConfigurationBuilder.h @@ -93,11 +93,12 @@ class SslConfigurationBuilder return *this; } - private: +private: bool m_enabled; std::string m_serverCAPath; std::string m_serverCAFile; std::string m_clientCertificateFile; + }; }} diff --git a/src/hotrod/impl/transport/tcp/SChannelTcpTransport.cpp b/src/hotrod/impl/transport/tcp/SChannelTcpTransport.cpp new file mode 100644 index 00000000..d2d27634 --- /dev/null +++ b/src/hotrod/impl/transport/tcp/SChannelTcpTransport.cpp @@ -0,0 +1,26 @@ +#include +#include "infinispan/hotrod/exceptions.h" +#include "hotrod/impl/transport/tcp/TcpTransportFactory.h" +#include "hotrod/sys/SChannelSocket.h" + +#include +#include +#include +#include + + +namespace infinispan { + namespace hotrod { + namespace transport { + + SChannelTcpTransport::SChannelTcpTransport( + const InetSocketAddress& a, TransportFactory& factory, const std::string& _serverCAPath, const std::string& _serverCAFile, const std::string& _clientCertificateFile) + : TcpTransport(a, factory, sys::SChannelSocket::create(_serverCAPath, _serverCAFile, _clientCertificateFile)) { + } + + //Testing purpose only! + SChannelTcpTransport::SChannelTcpTransport() + : TcpTransport(InetSocketAddress(), *(TransportFactory*)0, sys::Socket::create()) {} + } + } +} \ No newline at end of file diff --git a/src/hotrod/impl/transport/tcp/SChannelTcpTransport.h b/src/hotrod/impl/transport/tcp/SChannelTcpTransport.h new file mode 100644 index 00000000..84e17305 --- /dev/null +++ b/src/hotrod/impl/transport/tcp/SChannelTcpTransport.h @@ -0,0 +1,34 @@ +#ifndef ISPN_HOTROD_SCHANNELTCPTRANSPORT_H +#define ISPN_HOTROD_SCHANNELTCPTRANSPORT_H + +#include +#include + +#include "infinispan/hotrod/defs.h" +#include "hotrod/impl/transport/AbstractTransport.h" +#include "hotrod/impl/transport/tcp/Socket.h" +#include "TcpTransport.h" + +namespace infinispan { + namespace hotrod { + namespace transport { + + class TransportFactory; + + class SChannelTcpTransport : public TcpTransport + { + public: + SChannelTcpTransport(const InetSocketAddress& address, TransportFactory& factory, const std::string& _serverCAPath, const std::string& _serverCAFile, const std::string& _clientCertificateFile); + + protected: + //Testing only! + SChannelTcpTransport(); + + friend class TcpTransportFactory; + }; + + } + } +} // namespace infinispan::hotrod::transport::tcp + +#endif /* ISPN_HOTROD_SCHANNELTCPTRANSPORT_H */ diff --git a/src/hotrod/impl/transport/tcp/Socket.cpp b/src/hotrod/impl/transport/tcp/Socket.cpp index 39a56141..035d4bb0 100644 --- a/src/hotrod/impl/transport/tcp/Socket.cpp +++ b/src/hotrod/impl/transport/tcp/Socket.cpp @@ -53,7 +53,7 @@ void InputStream::read(char *buf, size_t size) { tmp_buffer += capacity; size -= capacity; } - capacity = socket.read(&buffer[0], BufferSize); + capacity = socket.read(&buffer[0], size < BufferSize ? size : BufferSize); ptr = &buffer[0]; hasMore = capacity < BufferSize ? false : true; } diff --git a/src/hotrod/impl/transport/tcp/TransportObjectFactory.cpp b/src/hotrod/impl/transport/tcp/TransportObjectFactory.cpp index bbfb02ce..0030311c 100644 --- a/src/hotrod/impl/transport/tcp/TransportObjectFactory.cpp +++ b/src/hotrod/impl/transport/tcp/TransportObjectFactory.cpp @@ -1,4 +1,5 @@ #include +#include #include "hotrod/impl/transport/tcp/TransportObjectFactory.h" #include "hotrod/impl/transport/tcp/TcpTransport.h" #include "hotrod/impl/transport/tcp/TcpTransportFactory.h" @@ -20,9 +21,15 @@ TransportObjectFactory::TransportObjectFactory( TcpTransport& TransportObjectFactory::makeObject(const InetSocketAddress& address) { if(tcpTransportFactory.isSslEnabled()) { +#if !defined _WIN32 && !defined _WIN64 return *(new SSLTcpTransport(address, tcpTransportFactory, tcpTransportFactory.getSslServerCAPath(), tcpTransportFactory.getSslServerCAFile(), tcpTransportFactory.getSslClientCertificateFile())); - } else { +#else +return *(new SChannelTcpTransport(address, tcpTransportFactory, + tcpTransportFactory.getSslServerCAPath(), tcpTransportFactory.getSslServerCAFile(), tcpTransportFactory.getSslClientCertificateFile())); +#endif + + } else { return *(new TcpTransport(address, tcpTransportFactory)); } } diff --git a/src/hotrod/sys/SChannelSocket.cpp b/src/hotrod/sys/SChannelSocket.cpp new file mode 100644 index 00000000..60cc8b3b --- /dev/null +++ b/src/hotrod/sys/SChannelSocket.cpp @@ -0,0 +1,1126 @@ +#include +#include + +#include "infinispan/hotrod/exceptions.h" +#include "hotrod/sys/Log.h" +#include "SChannelSocket.h" +#include "WinDef.h" + + +#define IO_BUFFER_SIZE 0x10000 + +#pragma comment(lib, "Crypt32.lib") +#pragma comment(lib, "WSock32.Lib") +#pragma comment(lib, "Crypt32.Lib") +#pragma comment(lib, "user32.lib") +#pragma comment(lib, "MSVCRTD.lib") +#pragma comment(lib, "Secur32.lib") + +using namespace infinispan::hotrod::sys; +HCERTSTORE SChannelSocket::hMyCertStore = NULL; + + +SChannelSocket::SChannelInitializer SChannelSocket::initializer; + +SChannelSocket::SChannelInitializer::SChannelInitializer() { + WSADATA wsaData; + g_pSSPI=InitSecurityInterface(); + if (g_pSSPI == NULL) + { + logAndThrow("", 0, "InitSecurityInferface: initialization failure"); + } + if (WSAStartup(0x0101, &wsaData)) + { + logAndThrow("",0,"Could not initialize winsock"); + } + +} + +SChannelSocket::SChannelSocket(const std::string& _serverCAPath, const std::string& _serverCAFile, const std::string& _clientCertificateFile) : + m_serverCAPath(_serverCAPath), m_serverCAFile(_serverCAFile), m_clientCertificateFile(_clientCertificateFile) + {} + +SChannelSocket::~SChannelSocket() +{ + cleanup(); +} + +INT SChannelSocket::connectToServer(std::string host, int iPortNumber, SOCKET * pSocket) +{ + SOCKET Socket; + struct sockaddr_in sin; + struct hostent *hp; + auto pszServerName = (LPSTR)host.c_str(); + + Socket = socket(PF_INET, SOCK_STREAM, 0); + if (Socket == INVALID_SOCKET) + { + ERROR("**** Error %d creating socket\n", WSAGetLastError()); + displayWinSockError(WSAGetLastError()); + return WSAGetLastError(); + } + sin.sin_family = AF_INET; + sin.sin_port = htons((u_short)iPortNumber); + if ((hp = gethostbyname(pszServerName)) == NULL) + { + ERROR("**** Error returned by gethostbyname\n"); + displayWinSockError(WSAGetLastError()); + return WSAGetLastError(); + } + else + { + memcpy(&sin.sin_addr, hp->h_addr, 4); + } + if (::connect(Socket, (struct sockaddr *)&sin, sizeof(sin)) == SOCKET_ERROR) + { + ERROR("**** Error %d connecting to \"%s\" (%s)\n", WSAGetLastError(), pszServerName, inet_ntoa(sin.sin_addr)); + closesocket(Socket); + displayWinSockError(WSAGetLastError()); + return WSAGetLastError(); + } + *pSocket = Socket; + return SEC_E_OK; +} + +void SChannelSocket::getNewClientCredentials() +{ + + CredHandle newHCreds; + SecPkgContext_IssuerListInfoEx IssuerListInfo; + PCCERT_CHAIN_CONTEXT pChainContext; + CERT_CHAIN_FIND_BY_ISSUER_PARA FindByIssuerPara; + PCCERT_CONTEXT pCertContext; + TimeStamp tsExpiry; + SECURITY_STATUS Status; + SCHANNEL_CRED SchannelCred; + + // Read list of trusted issuers from schannel. + Status = initializer.g_pSSPI->QueryContextAttributes(&hContext, SECPKG_ATTR_ISSUER_LIST_EX, (PVOID)&IssuerListInfo); + if (Status != SEC_E_OK) + { + ERROR("Error 0x%x querying issuer list info\n", Status); return; + } + + // Enumerate the client certificates. + ZeroMemory(&FindByIssuerPara, sizeof(FindByIssuerPara)); + + FindByIssuerPara.cbSize = sizeof(FindByIssuerPara); + FindByIssuerPara.pszUsageIdentifier = szOID_PKIX_KP_CLIENT_AUTH; + FindByIssuerPara.dwKeySpec = 0; + FindByIssuerPara.cIssuer = IssuerListInfo.cIssuers; + FindByIssuerPara.rgIssuer = IssuerListInfo.aIssuers; + pChainContext = NULL; + while (TRUE) + { // Find a certificate chain. + pChainContext = CertFindChainInStore(hMyCertStore, + X509_ASN_ENCODING, + 0, + CERT_CHAIN_FIND_BY_ISSUER, + &FindByIssuerPara, + pChainContext); + if (pChainContext == NULL) + { + ERROR("Error 0x%x finding cert chain\n", GetLastError()); break; + } + DEBUG("\ncertificate chain found\n"); + + // Get pointer to leaf certificate context. + pCertContext = pChainContext->rgpChain[0]->rgpElement[0]->pCertContext; + + // Create schannel credential. + SchannelCred.dwVersion = SCHANNEL_CRED_VERSION; + SchannelCred.cCreds = 1; + SchannelCred.paCred = &pCertContext; + + Status = initializer.g_pSSPI->AcquireCredentialsHandleA(NULL, // Name of principal + UNISP_NAME_A, // Name of package + SECPKG_CRED_OUTBOUND, // Flags indicating use + NULL, // Pointer to logon ID + &SchannelCred, // Package specific data + NULL, // Pointer to GetKey() func + NULL, // Value to pass to GetKey() + &newHCreds, // (out) Cred Handle + &tsExpiry); // (out) Lifetime (optional) + isCredsInitialized = true; + if (Status != SEC_E_OK) + { + ERROR("**** Error 0x%x returned by AcquireCredentialsHandle\n", Status); + continue; + } + DEBUG("\nnew schannel credential created\n"); + initializer.g_pSSPI->FreeCredentialsHandle(&hCred); // Destroy the old credentials. + hCred = newHCreds; + } +} + +SECURITY_STATUS SChannelSocket::clientHandshakeLoop(BOOL doInitialRead, SecBuffer *pExtraData) +{ + + SecBufferDesc OutBuffer, InBuffer; + SecBuffer InBuffers[2], OutBuffers[1]; + DWORD dwSSPIFlags, dwSSPIOutFlags, cbData, cbIoBuffer; + TimeStamp tsExpiry; + SECURITY_STATUS scRet; + BOOL fDoRead; + + static UCHAR IoBuffer[IO_BUFFER_SIZE]; + + dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | + ISC_RET_EXTENDED_ERROR | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM; + + cbIoBuffer = 0; + fDoRead = doInitialRead; + + // Loop until the handshake is finished or an error occurs. + scRet = SEC_I_CONTINUE_NEEDED; + while (scRet == SEC_I_CONTINUE_NEEDED || + scRet == SEC_E_INCOMPLETE_MESSAGE || + scRet == SEC_I_INCOMPLETE_CREDENTIALS) + { + if (0 == cbIoBuffer || scRet == SEC_E_INCOMPLETE_MESSAGE) // Read data from server. + { + if (fDoRead) + { + cbData = recv(Client_Socket, (PCHAR)IoBuffer + cbIoBuffer, IO_BUFFER_SIZE - cbIoBuffer, 0); + if (cbData == SOCKET_ERROR) + { + ERROR("**** Error %d reading data from server\n", WSAGetLastError()); + scRet = SEC_E_INTERNAL_ERROR; + break; + } + else if (cbData == 0) + { + ERROR("**** Server unexpectedly disconnected\n"); + scRet = SEC_E_INTERNAL_ERROR; + break; + } + DEBUG("%d bytes of handshake data received\n", cbData); + cbIoBuffer += cbData; + } + else + fDoRead = TRUE; + } + + // Set up the input buffers. Buffer 0 is used to pass in data + // received from the server. Schannel will consume some or all + // of this. Leftover data (if any) will be placed in buffer 1 and + // given a buffer type of SECBUFFER_EXTRA. + InBuffers[0].pvBuffer = IoBuffer; + InBuffers[0].cbBuffer = cbIoBuffer; + InBuffers[0].BufferType = SECBUFFER_TOKEN; + + InBuffers[1].pvBuffer = NULL; + InBuffers[1].cbBuffer = 0; + InBuffers[1].BufferType = SECBUFFER_EMPTY; + + InBuffer.cBuffers = 2; + InBuffer.pBuffers = InBuffers; + InBuffer.ulVersion = SECBUFFER_VERSION; + + + // Set up the output buffers. These are initialized to NULL + // so as to make it less likely we'll attempt to free random + // garbage later. + OutBuffers[0].pvBuffer = NULL; + OutBuffers[0].BufferType = SECBUFFER_TOKEN; + OutBuffers[0].cbBuffer = 0; + + OutBuffer.cBuffers = 1; + OutBuffer.pBuffers = OutBuffers; + OutBuffer.ulVersion = SECBUFFER_VERSION; + + // Call InitializeSecurityContext. + scRet = initializer.g_pSSPI->InitializeSecurityContextA(&hCred, + &hContext, + NULL, + dwSSPIFlags, + 0, + SECURITY_NATIVE_DREP, + &InBuffer, + 0, + NULL, + &OutBuffer, + &dwSSPIOutFlags, + &tsExpiry); + + + // If InitializeSecurityContext was successful (or if the error was + // one of the special extended ones), send the contends of the output + // buffer to the server. + if (scRet == SEC_E_OK || + scRet == SEC_I_CONTINUE_NEEDED || + FAILED(scRet) && (dwSSPIOutFlags & ISC_RET_EXTENDED_ERROR)) + { + if (OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL) + { + cbData = send(Client_Socket, (const PCHAR)OutBuffers[0].pvBuffer, OutBuffers[0].cbBuffer, 0); + if (cbData == SOCKET_ERROR || cbData == 0) + { + ERROR("**** Error %d sending data to server\n", WSAGetLastError()); + displayWinSockError(WSAGetLastError()); + initializer.g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer); + initializer.g_pSSPI->DeleteSecurityContext(&hContext); + return SEC_E_INTERNAL_ERROR; + } + // Free output buffer. + initializer.g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer); + OutBuffers[0].pvBuffer = NULL; + } + } + + // If InitializeSecurityContext returned SEC_E_INCOMPLETE_MESSAGE, + // then we need to read more data from the server and try again. + if (scRet == SEC_E_INCOMPLETE_MESSAGE) continue; + + // If InitializeSecurityContext returned SEC_E_OK, then the + // handshake completed successfully. + if (scRet == SEC_E_OK) + { + // If the "extra" buffer contains data, this is encrypted application + // protocol layer stuff. It needs to be saved. The application layer + // will later decrypt it with DecryptMessage. + DEBUG("Handshake was successful\n"); + + if (InBuffers[1].BufferType == SECBUFFER_EXTRA) + { + pExtraData->pvBuffer = LocalAlloc(LMEM_FIXED, InBuffers[1].cbBuffer); + if (pExtraData->pvBuffer == NULL) + { + ERROR("**** Out of memory (2)\n"); + return SEC_E_INTERNAL_ERROR; + } + + MoveMemory(pExtraData->pvBuffer, + IoBuffer + (cbIoBuffer - InBuffers[1].cbBuffer), + InBuffers[1].cbBuffer); + pExtraData->cbBuffer = InBuffers[1].cbBuffer; + pExtraData->BufferType = SECBUFFER_TOKEN; + DEBUG("%d bytes of app data was bundled with handshake data\n", pExtraData->cbBuffer); + } + else + { + pExtraData->pvBuffer = NULL; + pExtraData->cbBuffer = 0; + pExtraData->BufferType = SECBUFFER_EMPTY; + } + break; // Bail out to quit + } + + // Check for fatal error. + if (FAILED(scRet)) + { + ERROR("**** Error 0x%x returned by InitializeSecurityContext (2)\n", scRet); + break; + } + + // If InitializeSecurityContext returned SEC_I_INCOMPLETE_CREDENTIALS, + // then the server just requested client authentication. + if (scRet == SEC_I_INCOMPLETE_CREDENTIALS) + { + // Busted. The server has requested client authentication and + // the credential we supplied didn't contain a client certificate. + // This function will read the list of trusted certificate + // authorities ("issuers") that was received from the server + // and attempt to find a suitable client certificate that + // was issued by one of these. If this function is successful, + // then we will connect using the new certificate. Otherwise, + // we will attempt to connect anonymously (using our current credentials). + getNewClientCredentials(); + // Go around again. + fDoRead = FALSE; + scRet = SEC_I_CONTINUE_NEEDED; + continue; + } + + // Copy any leftover data from the "extra" buffer, and go around again. + if (InBuffers[1].BufferType == SECBUFFER_EXTRA) + { + MoveMemory(IoBuffer, IoBuffer + (cbIoBuffer - InBuffers[1].cbBuffer), InBuffers[1].cbBuffer); + cbIoBuffer = InBuffers[1].cbBuffer; + } + else + cbIoBuffer = 0; + } + + // Delete the security context in the case of a fatal error. + if (FAILED(scRet)) initializer.g_pSSPI->DeleteSecurityContext(&hContext); + return scRet; +} + +SECURITY_STATUS SChannelSocket::performClientHandshake(std::string host, SecBuffer* pExtraData) +{ + SecBuffer OutBuffers[1]; + DWORD dwSSPIFlags, dwSSPIOutFlags, cbData; + TimeStamp tsExpiry; + SECURITY_STATUS scRet; + SecBufferDesc OutBuffer; + auto pszServerName = (LPSTR)host.c_str(); + + dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY | + ISC_RET_EXTENDED_ERROR | ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_STREAM; + + // Initiate a ClientHello message and generate a token. + OutBuffers[0].pvBuffer = NULL; + OutBuffers[0].BufferType = SECBUFFER_TOKEN; + OutBuffers[0].cbBuffer = 0; + + OutBuffer.cBuffers = 1; + OutBuffer.pBuffers = OutBuffers; + OutBuffer.ulVersion = SECBUFFER_VERSION; + + scRet = initializer.g_pSSPI->InitializeSecurityContextA(&hCred, + NULL, + pszServerName, + dwSSPIFlags, + 0, + SECURITY_NATIVE_DREP, + NULL, + 0, + &hContext, + &OutBuffer, + &dwSSPIOutFlags, + &tsExpiry); + + if (scRet != SEC_I_CONTINUE_NEEDED) + { + ERROR("**** Error %d returned by InitializeSecurityContext (1)\n", scRet); + return scRet; + } + + // Send response to server if there is one. + if (OutBuffers[0].cbBuffer != 0 && OutBuffers[0].pvBuffer != NULL) + { + cbData = send(Client_Socket, (const PCHAR)OutBuffers[0].pvBuffer, OutBuffers[0].cbBuffer, 0); + if (cbData == SOCKET_ERROR || cbData == 0) + { + ERROR("**** Error %d sending data to server (1)\n", WSAGetLastError()); + initializer.g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer); + initializer.g_pSSPI->DeleteSecurityContext(&hContext); + return SEC_E_INTERNAL_ERROR; + } + DEBUG("%d bytes of handshake data sent\n", cbData); + //{ PrintHexDump(cbData, OutBuffers[0].pvBuffer); printf("\n"); } + initializer.g_pSSPI->FreeContextBuffer(OutBuffers[0].pvBuffer); // Free output buffer. + OutBuffers[0].pvBuffer = NULL; + } + + return clientHandshakeLoop(TRUE, pExtraData); +} + +void SChannelSocket::displayWinVerifyTrustError(DWORD Status) +{ + LPSTR pszName = NULL; + + switch (Status) + { + case CERT_E_EXPIRED: pszName = "CERT_E_EXPIRED"; break; + case CERT_E_VALIDITYPERIODNESTING: pszName = "CERT_E_VALIDITYPERIODNESTING"; break; + case CERT_E_ROLE: pszName = "CERT_E_ROLE"; break; + case CERT_E_PATHLENCONST: pszName = "CERT_E_PATHLENCONST"; break; + case CERT_E_CRITICAL: pszName = "CERT_E_CRITICAL"; break; + case CERT_E_PURPOSE: pszName = "CERT_E_PURPOSE"; break; + case CERT_E_ISSUERCHAINING: pszName = "CERT_E_ISSUERCHAINING"; break; + case CERT_E_MALFORMED: pszName = "CERT_E_MALFORMED"; break; + case CERT_E_UNTRUSTEDROOT: pszName = "CERT_E_UNTRUSTEDROOT"; break; + case CERT_E_CHAINING: pszName = "CERT_E_CHAINING"; break; + case TRUST_E_FAIL: pszName = "TRUST_E_FAIL"; break; + case CERT_E_REVOKED: pszName = "CERT_E_REVOKED"; break; + case CERT_E_UNTRUSTEDTESTROOT: pszName = "CERT_E_UNTRUSTEDTESTROOT"; break; + case CERT_E_REVOCATION_FAILURE: pszName = "CERT_E_REVOCATION_FAILURE"; break; + case CERT_E_CN_NO_MATCH: pszName = "CERT_E_CN_NO_MATCH"; break; + case CERT_E_WRONG_USAGE: pszName = "CERT_E_WRONG_USAGE"; break; + default: pszName = "(unknown)"; break; + } + ERROR("Error 0x%x (%s) returned by CertVerifyCertificateChainPolicy!\n", Status, pszName); +} + + +DWORD SChannelSocket::verifyServerCertificate(PCCERT_CONTEXT pServerCert, std::string host, DWORD dwCertFlags) +{ + HTTPSPolicyCallbackData polHttps; + CERT_CHAIN_POLICY_PARA PolicyPara; + CERT_CHAIN_POLICY_STATUS PolicyStatus; + CERT_CHAIN_PARA ChainPara; + PCCERT_CHAIN_CONTEXT pChainContext = NULL; + DWORD cchServerName, Status; + LPSTR rgszUsages[] = { szOID_PKIX_KP_SERVER_AUTH, + szOID_SERVER_GATED_CRYPTO, + szOID_SGC_NETSCAPE }; + auto pszServerName = (PSTR)host.c_str(); + DWORD cUsages = sizeof(rgszUsages) / sizeof(LPSTR); + PWSTR pwszServerName = NULL; + if (pServerCert == NULL) + { + if (pChainContext) CertFreeCertificateChain(pChainContext); + return SEC_E_WRONG_PRINCIPAL; + } + + // Convert server name to unicode. + if (pszServerName == NULL || strlen(pszServerName) == 0) + { + return SEC_E_WRONG_PRINCIPAL; + } + + cchServerName = MultiByteToWideChar(CP_ACP, 0, pszServerName, -1, NULL, 0); + pwszServerName = (PWSTR)LocalAlloc(LMEM_FIXED, cchServerName * sizeof(WCHAR)); + if (pwszServerName == NULL) + { + return SEC_E_INSUFFICIENT_MEMORY; + } + + cchServerName = MultiByteToWideChar(CP_ACP, 0, pszServerName, -1, pwszServerName, cchServerName); + if (cchServerName == 0) + { + if (pwszServerName) LocalFree(pwszServerName); + return SEC_E_WRONG_PRINCIPAL; + } + + // Build certificate chain. + ZeroMemory(&ChainPara, sizeof(ChainPara)); + ChainPara.cbSize = sizeof(ChainPara); + ChainPara.RequestedUsage.dwType = USAGE_MATCH_TYPE_OR; + ChainPara.RequestedUsage.Usage.cUsageIdentifier = cUsages; + ChainPara.RequestedUsage.Usage.rgpszUsageIdentifier = rgszUsages; + + if (!CertGetCertificateChain(NULL, + pServerCert, + NULL, + pServerCert->hCertStore, + &ChainPara, + 0, + NULL, + &pChainContext)) + { + Status = GetLastError(); + ERROR("Error 0x%x returned by CertGetCertificateChain!\n", Status); + if (pChainContext) CertFreeCertificateChain(pChainContext); + if (pwszServerName) LocalFree(pwszServerName); + return Status; + } + + // Validate certificate chain. + ZeroMemory(&polHttps, sizeof(HTTPSPolicyCallbackData)); + polHttps.cbStruct = sizeof(HTTPSPolicyCallbackData); + polHttps.dwAuthType = AUTHTYPE_SERVER; + polHttps.fdwChecks = dwCertFlags; + polHttps.pwszServerName = pwszServerName; + + memset(&PolicyPara, 0, sizeof(PolicyPara)); + PolicyPara.cbSize = sizeof(PolicyPara); + PolicyPara.pvExtraPolicyPara = &polHttps; + + memset(&PolicyStatus, 0, sizeof(PolicyStatus)); + PolicyStatus.cbSize = sizeof(PolicyStatus); + + if (!CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, + pChainContext, + &PolicyPara, + &PolicyStatus)) + { + Status = GetLastError(); + ERROR("Error 0x%x returned by CertVerifyCertificateChainPolicy!\n", Status); + if (pChainContext) CertFreeCertificateChain(pChainContext); + if (pwszServerName) LocalFree(pwszServerName); + return Status; + } + + if (PolicyStatus.dwError) + { + Status = PolicyStatus.dwError; + displayWinVerifyTrustError(Status); + if (pChainContext) CertFreeCertificateChain(pChainContext); + if (pwszServerName) LocalFree(pwszServerName); + return Status; + } + + Status = SEC_E_OK; + if (pChainContext) CertFreeCertificateChain(pChainContext); + if (pwszServerName) LocalFree(pwszServerName); + return Status; +} + +void SChannelSocket::connect(const std::string & host, int port, int timeout) +{ + SCHANNEL_CRED SchannelCred; + DWORD dwProtocol = SP_PROT_TLS1; // SP_PROT_TLS1; // SP_PROT_PCT1; SP_PROT_SSL2; SP_PROT_SSL3; 0=default + ALG_ID aiKeyExch = 0; // = default; CALG_DH_EPHEM; CALG_RSA_KEYX; + SECURITY_STATUS Status; + DWORD cSupportedAlgs = 0; + ALG_ID rgbSupportedAlgs[16]; + TimeStamp tsExpiry; + SecBuffer ExtraData; + std::ifstream is( m_serverCAFile.c_str(), std::ios::binary); + if (is) { + is.seekg(0, is.end); + int length = is.tellg(); + is.seekg(0, is.beg); + BYTE * certificate = new BYTE[length]; + is.read((PCHAR)certificate, length); + PCCERT_CONTEXT pCertContext; + + // If a user name is specified, then attempt to find a client + // certificate. Otherwise, just create a NULL credential. + if (!m_clientCertificateFile.empty()) + { + // Open the "MY" certificate store, where IE stores client certificates. + // Windows maintains 4 stores -- MY, CA, ROOT, SPC. + if (hMyCertStore == NULL) + { + hMyCertStore = CertOpenSystemStore(0, "MY"); + if (!hMyCertStore) + { + printf("**** Error 0x%x returned by CertOpenSystemStore\n", GetLastError()); + logAndThrow(host, port, "ERROR"); + } + } + + // Find client certificate. Note that this sample just searches for a + // certificate that contains the user name somewhere in the subject name. + // A real application should be a bit less casual. + pCertContext = CertFindCertificateInStore(hMyCertStore, // hCertStore + X509_ASN_ENCODING, // dwCertEncodingType + 0, // dwFindFlags + CERT_FIND_SUBJECT_STR_A + ,// dwFindType + m_clientCertificateFile.c_str(), // *pvFindPara + NULL); // pPrevCertContext + + if (pCertContext == NULL) + { + ERROR("**** Error 0x%x returned by CertFindCertificateInStore\n", GetLastError()); + if (GetLastError() == CRYPT_E_NOT_FOUND) + ERROR("CRYPT_E_NOT_FOUND - property doesn't exist\n"); + logAndThrow(host, port, "ERROR"); + } + } + else + { + pCertContext = CertCreateCertificateContext(X509_ASN_ENCODING, certificate, length); + } + ZeroMemory(&SchannelCred, sizeof(SchannelCred)); + SchannelCred.dwVersion = SCHANNEL_CRED_VERSION; + if (pCertContext) + { + SchannelCred.cCreds = 1; + SchannelCred.paCred = &pCertContext; + } + + SchannelCred.grbitEnabledProtocols = dwProtocol; + + if (aiKeyExch) rgbSupportedAlgs[cSupportedAlgs++] = aiKeyExch; + + if (cSupportedAlgs) + { + SchannelCred.cSupportedAlgs = cSupportedAlgs; + SchannelCred.palgSupportedAlgs = rgbSupportedAlgs; + } + + SchannelCred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; + + // The SCH_CRED_MANUAL_CRED_VALIDATION flag is specified because + // this sample verifies the server certificate manually. + // Applications that expect to run on WinNT, Win9x, or WinME + // should specify this flag and also manually verify the server + // certificate. Applications running on newer versions of Windows can + // leave off this flag, in which case the InitializeSecurityContext + // function will validate the server certificate automatically. + SchannelCred.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION; + + // Create an SSPI credential. + Status = SChannelSocket::initializer.g_pSSPI->AcquireCredentialsHandleA(NULL, // Name of principal + UNISP_NAME_A, // Name of package + SECPKG_CRED_OUTBOUND, // Flags indicating use + NULL, // Pointer to logon ID + &SchannelCred, // Package specific data + NULL, // Pointer to GetKey() func + NULL, // Value to pass to GetKey() + &hCred, // (out) Cred Handle + &tsExpiry); // (out) Lifetime (optional) + + if (Status != SEC_E_OK) + { + ERROR("**** Error 0x%x returned by AcquireCredentialsHandle\n", Status); + } + if (pCertContext) CertFreeCertificateContext(pCertContext); + + Client_Socket = INVALID_SOCKET; + connectToServer(host, port, &Client_Socket); + performClientHandshake(host, &ExtraData); + isContextInitialized = true; + // Authenticate server's credentials. Get server's certificate. + Status = initializer.g_pSSPI->QueryContextAttributes(&hContext, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (PVOID)&pRemoteCertContext); + if (Status != SEC_E_OK) + { + cleanup(); + ERROR("Error 0x%x querying remote certificate\n", Status); + logAndThrow(host, port, "Error querying remote certificate\n"); + } // + DEBUG("----- Server Credentials Authenticated \n"); + // Attempt to validate server certificate. + if (this->onlyVerified) + { + Status = verifyServerCertificate(pRemoteCertContext, host, 0); + if (Status) + { + ERROR("**** Error 0x%x authenticating server credentials!\n", Status); + cleanup(); + logAndThrow(host,port,"**** Error 0x%x authenticating server credentials!\n"); + } + // The server certificate did not validate correctly. At this point, we cannot tell + // if we are connecting to the correct server, or if we are connecting to a + // "man in the middle" attack server - Best to just abort the connection. + DEBUG("----- Server Certificate Verified\n"); + // Free the server certificate context. + } + CertFreeCertificateContext(pRemoteCertContext); + pRemoteCertContext = NULL; + DEBUG("----- Server certificate context released \n"); + + // Display connection info. + displayConnectionInfo(); + DEBUG("----- Secure Connection Info\n"); + SecPkgContext_StreamSizes Sizes; + initializer.g_pSSPI->QueryContextAttributes(&hContext, SECPKG_ATTR_STREAM_SIZES, &Sizes); + auto cbIoBufferLength = Sizes.cbHeader + Sizes.cbMaximumMessage + Sizes.cbTrailer; + pbRBuffer = (PBYTE)LocalAlloc(LMEM_FIXED, cbIoBufferLength); + pbWBuffer = (PBYTE)LocalAlloc(LMEM_FIXED, cbIoBufferLength); + recvBuff = (PBYTE)LocalAlloc(LMEM_FIXED, cbIoBufferLength); + this->cbRBuffer = 0; + if (ExtraData.cbBuffer > 0) + { + MoveMemory(recvBuff, ExtraData.pvBuffer, ExtraData.cbBuffer); + this->cbRBuffer = ExtraData.cbBuffer; + LocalFree(ExtraData.pvBuffer); + } + } +} + +void SChannelSocket::close() +{ + cleanup(); +} + +void SChannelSocket::setTcpNoDelay(bool tcpNoDelay) +{ +} + +void SChannelSocket::setTimeout(int timeout) +{ +} + +void throwIOErr(const std::string& host, int port, const PCHAR msg, int errnum) { + std::ostringstream m; + m << msg; + m << " (host: " << host; + m << " port: " << port << ")"; + + if (errnum != 0) { + CHAR buf[200]; + if (FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM, NULL, errnum, 0, buf, 200, NULL) != 0) { + m << " " << buf; + } + else { + m << " Unknown WSA error " << errnum; + } + } + throw std::exception(m.str().c_str()); +} + +SECURITY_STATUS SChannelSocket::readDecrypt(const DWORD bufsize, size_t *read_counter) +{ + SecBuffer ExtraBuffer; + SecBuffer *pDataBuffer, *pExtraBuffer; + + SECURITY_STATUS scRet; + SecBufferDesc Message; + SecBuffer Buffers[4]; + DWORD cbData, length; + int i; + + // Read data from server until done. + *read_counter = 0; + scRet = SEC_E_INCOMPLETE_MESSAGE; + do + { + if (cbRBuffer == 0 || scRet == SEC_E_INCOMPLETE_MESSAGE) // get the data + { + cbData = recv(Client_Socket, (PCHAR) recvBuff, bufsize, 0); + if (cbData == SOCKET_ERROR) + { + ERROR("**** Error %d reading data from server\n", WSAGetLastError()); + scRet = SEC_E_INTERNAL_ERROR; + break; + } + else if (cbData == 0) // Server disconnected. + { + if (cbRBuffer) + { + ERROR("**** Server unexpectedly disconnected\n"); + scRet = SEC_E_INTERNAL_ERROR; + return scRet; + } + else + break; // All Done + } + else // success + { + DEBUG("%d bytes of (encrypted) application data received\n", cbData); + cbRBuffer += cbData; + } + } + + // Decrypt the received data. + Buffers[0].pvBuffer = recvBuff; + Buffers[0].cbBuffer = cbRBuffer; + Buffers[0].BufferType = SECBUFFER_DATA; // Initial Type of the buffer 1 + Buffers[1].BufferType = SECBUFFER_EMPTY; // Initial Type of the buffer 2 + Buffers[2].BufferType = SECBUFFER_EMPTY; // Initial Type of the buffer 3 + Buffers[3].BufferType = SECBUFFER_EMPTY; // Initial Type of the buffer 4 + + Message.ulVersion = SECBUFFER_VERSION; // Version number + Message.cBuffers = 4; // Number of buffers - must contain four SecBuffer structures. + Message.pBuffers = Buffers; // Pointer to array of buffers + scRet = initializer.g_pSSPI->DecryptMessage(&hContext, &Message, 0, NULL); + if (scRet == SEC_I_CONTEXT_EXPIRED) break; // Server signalled end of session +// if( scRet == SEC_E_INCOMPLETE_MESSAGE - Input buffer has partial encrypted record, read more + if (scRet != SEC_E_OK && + scRet != SEC_I_RENEGOTIATE && + scRet != SEC_I_CONTEXT_EXPIRED) + { + ERROR("**** Error DecryptMessage "); + //DisplaySECError((DWORD)scRet); + return scRet; + } + + // Locate data and (optional) extra buffers. + pDataBuffer = NULL; + pExtraBuffer = NULL; + for (i = 0; i < 4; i++) + { + if (pDataBuffer == NULL && Buffers[i].BufferType == SECBUFFER_DATA) pDataBuffer = &Buffers[i]; + if (pExtraBuffer == NULL && Buffers[i].BufferType == SECBUFFER_EXTRA) pExtraBuffer = &Buffers[i]; + } + + MoveMemory(pbRBuffer+*read_counter, pDataBuffer->pvBuffer, pDataBuffer->cbBuffer); + *read_counter += pDataBuffer->cbBuffer; + + // Move any "extra" data to the input buffer. + if (pExtraBuffer) + { + MoveMemory(recvBuff, pExtraBuffer->pvBuffer, pExtraBuffer->cbBuffer); + cbRBuffer = pExtraBuffer->cbBuffer; // printf("cbIoBuffer= %d \n", cbIoBuffer); + + } + else + cbRBuffer = 0; + + // The server wants to perform another handshake sequence. + if (scRet == SEC_I_RENEGOTIATE) + { + printf("Server requested renegotiate!\n"); + scRet = clientHandshakeLoop(FALSE, &ExtraBuffer); + if (scRet != SEC_E_OK) return scRet; + + if (ExtraBuffer.pvBuffer) // Move any "extra" data to the input buffer. + { + MoveMemory(pbRBuffer, ExtraBuffer.pvBuffer, ExtraBuffer.cbBuffer); + cbRBuffer = ExtraBuffer.cbBuffer; + } + } + } + while (scRet == SEC_E_INCOMPLETE_MESSAGE || cbRBuffer>0); + return SEC_E_OK; +} + +DWORD SChannelSocket::encryptSend(size_t len, SecPkgContext_StreamSizes Sizes) +// http://msdn.microsoft.com/en-us/library/aa375378(VS.85).aspx +// The encrypted message is encrypted in place, overwriting the original contents of its buffer. +{ + SECURITY_STATUS scRet; // unsigned long cbBuffer; // Size of the buffer, in bytes + SecBufferDesc Message; // unsigned long BufferType; // Type of the buffer (below) + SecBuffer Buffers[4]; // void SEC_FAR * pvBuffer; // Pointer to the buffer + DWORD cbMessage, cbData; + PBYTE pbMessage; + + pbMessage = pbWBuffer + Sizes.cbHeader; // Offset by "header size" + cbMessage = (DWORD)len; + + // Encrypt the HTTP request. + Buffers[0].pvBuffer = pbWBuffer; // Pointer to buffer 1 + Buffers[0].cbBuffer = Sizes.cbHeader; // length of header + Buffers[0].BufferType = SECBUFFER_STREAM_HEADER; // Type of the buffer + + Buffers[1].pvBuffer = pbMessage; // Pointer to buffer 2 + Buffers[1].cbBuffer = cbMessage; // length of the message + Buffers[1].BufferType = SECBUFFER_DATA; // Type of the buffer + + Buffers[2].pvBuffer = pbMessage + cbMessage; // Pointer to buffer 3 + Buffers[2].cbBuffer = Sizes.cbTrailer; // length of the trailor + Buffers[2].BufferType = SECBUFFER_STREAM_TRAILER; // Type of the buffer + + Buffers[3].pvBuffer = SECBUFFER_EMPTY; // Pointer to buffer 4 + Buffers[3].cbBuffer = SECBUFFER_EMPTY; // length of buffer 4 + Buffers[3].BufferType = SECBUFFER_EMPTY; // Type of the buffer 4 + + + Message.ulVersion = SECBUFFER_VERSION; // Version number + Message.cBuffers = 4; // Number of buffers - must contain four SecBuffer structures. + Message.pBuffers = Buffers; // Pointer to array of buffers + scRet = initializer.g_pSSPI->EncryptMessage(&hContext, 0, &Message, 0); // must contain four SecBuffer structures. + if (FAILED(scRet)) + { + printf("**** Error 0x%x returned by EncryptMessage\n", scRet); return scRet; + } + + // Send the encrypted data to the server. len flags + cbData = send(Client_Socket, (const PCHAR)pbWBuffer, Buffers[0].cbBuffer + Buffers[1].cbBuffer + Buffers[2].cbBuffer, 0); + + DEBUG("%d bytes of encrypted data sent\n", cbData); + return cbData; // send( Socket, pbIoBuffer, Sizes.cbHeader + strlen(pbMessage) + Sizes.cbTrailer, 0 ); +} + +void SChannelSocket::write(const char *p, size_t length) +{ + PBYTE pbIoBuffer; + SecPkgContext_StreamSizes Sizes; + initializer.g_pSSPI->QueryContextAttributes(&hContext, SECPKG_ATTR_STREAM_SIZES, &Sizes); + auto cbIoBufferLength = Sizes.cbHeader + Sizes.cbMaximumMessage + Sizes.cbTrailer; + memcpy(pbWBuffer+Sizes.cbHeader, p, length); + encryptSend(length, Sizes); +} + +size_t SChannelSocket::read(char *p, size_t length) +{ + unsigned int count = this->ready_bytes - this->offset_bytes; + // On emprty buffer read bytes from TLS + if (count <= 0) + { + SecPkgContext_StreamSizes Sizes; + size_t read_counter; + initializer.g_pSSPI->QueryContextAttributes(&hContext, SECPKG_ATTR_STREAM_SIZES, &Sizes); + auto cbIoBufferLength = Sizes.cbHeader + Sizes.cbMaximumMessage + Sizes.cbTrailer; + readDecrypt(cbIoBufferLength, &read_counter); + offset_bytes = 0; + count=ready_bytes = read_counter; + } + // Copy the bytes into the user buffer + if (count <= length) + { + memcpy(p, pbRBuffer + offset_bytes, count); + offset_bytes = ready_bytes = 0; + return count; + } + else + { + memcpy(p, pbRBuffer + offset_bytes, length); + offset_bytes+=length; + return length; + } +} + +int SChannelSocket::getSocket() +{ + return 0; +} + +SChannelSocket* SChannelSocket::create(const std::string& _serverCAPath, const std::string& _serverCAFile, const std::string& _clientCertificateFile) { + return new SChannelSocket(_serverCAPath, _serverCAFile, _clientCertificateFile); +} + +void SChannelSocket::logAndThrow(const std::string& host, const int port, const std::string& msg) { + DEBUG(msg.c_str()); + throw TransportException(host, port, msg); +} + +void SChannelSocket::cleanup() +{ + // Free the server certificate context. + if (pRemoteCertContext) + { + CertFreeCertificateContext(pRemoteCertContext); + pRemoteCertContext = NULL; + } + + // Free SSPI context handle. + if (isContextInitialized) + { + initializer.g_pSSPI->DeleteSecurityContext(&hContext); + isContextInitialized = FALSE; + } + + // Free SSPI credentials handle. + if (isCredsInitialized) + { + initializer.g_pSSPI->FreeCredentialsHandle(&hCred); + isCredsInitialized = FALSE; + } + + // Close socket. + if (Client_Socket != INVALID_SOCKET) closesocket(Client_Socket); + + // Close "MY" certificate store. + if (hMyCertStore) CertCloseStore(hMyCertStore, 0); + hMyCertStore = NULL; + + if (recvBuff != NULL) LocalFree(recvBuff); + if (pbWBuffer != NULL) LocalFree(pbWBuffer); + if (pbRBuffer != NULL) LocalFree(pbRBuffer); + recvBuff = pbWBuffer = pbRBuffer = NULL; +} + +void SChannelSocket::displayWinSockError(DWORD ErrCode) +{ + LPSTR pszName = NULL; // http://www.sockets.com/err_lst1.htm#WSANO_DATA + + switch (ErrCode) // http://msdn.microsoft.com/en-us/library/ms740668(VS.85).aspx + { + case 10035: pszName = "WSAEWOULDBLOCK "; break; + case 10036: pszName = "WSAEINPROGRESS "; break; + case 10037: pszName = "WSAEALREADY "; break; + case 10038: pszName = "WSAENOTSOCK "; break; + case 10039: pszName = "WSAEDESTADDRREQ "; break; + case 10040: pszName = "WSAEMSGSIZE "; break; + case 10041: pszName = "WSAEPROTOTYPE "; break; + case 10042: pszName = "WSAENOPROTOOPT "; break; + case 10043: pszName = "WSAEPROTONOSUPPORT"; break; + case 10044: pszName = "WSAESOCKTNOSUPPORT"; break; + case 10045: pszName = "WSAEOPNOTSUPP "; break; + case 10046: pszName = "WSAEPFNOSUPPORT "; break; + case 10047: pszName = "WSAEAFNOSUPPORT "; break; + case 10048: pszName = "WSAEADDRINUSE "; break; + case 10049: pszName = "WSAEADDRNOTAVAIL "; break; + case 10050: pszName = "WSAENETDOWN "; break; + case 10051: pszName = "WSAENETUNREACH "; break; + case 10052: pszName = "WSAENETRESET "; break; + case 10053: pszName = "WSAECONNABORTED "; break; + case 10054: pszName = "WSAECONNRESET "; break; + case 10055: pszName = "WSAENOBUFS "; break; + case 10056: pszName = "WSAEISCONN "; break; + case 10057: pszName = "WSAENOTCONN "; break; + case 10058: pszName = "WSAESHUTDOWN "; break; + case 10059: pszName = "WSAETOOMANYREFS "; break; + case 10060: pszName = "WSAETIMEDOUT "; break; + case 10061: pszName = "WSAECONNREFUSED "; break; + case 10062: pszName = "WSAELOOP "; break; + case 10063: pszName = "WSAENAMETOOLONG "; break; + case 10064: pszName = "WSAEHOSTDOWN "; break; + case 10065: pszName = "WSAEHOSTUNREACH "; break; + case 10066: pszName = "WSAENOTEMPTY "; break; + case 10067: pszName = "WSAEPROCLIM "; break; + case 10068: pszName = "WSAEUSERS "; break; + case 10069: pszName = "WSAEDQUOT "; break; + case 10070: pszName = "WSAESTALE "; break; + case 10071: pszName = "WSAEREMOTE "; break; + case 10091: pszName = "WSASYSNOTREADY "; break; + case 10092: pszName = "WSAVERNOTSUPPORTED"; break; + case 10093: pszName = "WSANOTINITIALISED "; break; + case 11001: pszName = "WSAHOST_NOT_FOUND "; break; + case 11002: pszName = "WSATRY_AGAIN "; break; + case 11003: pszName = "WSANO_RECOVERY "; break; + case 11004: pszName = "WSANO_DATA "; break; + } + ERROR("Error 0x%x (%s)\n", ErrCode, pszName); +} + +void SChannelSocket::displayConnectionInfo() +{ + SECURITY_STATUS Status; + SecPkgContext_ConnectionInfo ConnectionInfo; + + Status = initializer.g_pSSPI->QueryContextAttributes(&hContext, SECPKG_ATTR_CONNECTION_INFO, (PVOID)&ConnectionInfo); + if (Status != SEC_E_OK) { printf("Error 0x%x querying connection info\n", Status); return; } + + printf("\n"); + + switch (ConnectionInfo.dwProtocol) + { + case SP_PROT_TLS1_CLIENT: + DEBUG("Protocol: TLS1\n"); + break; + + case SP_PROT_SSL3_CLIENT: + DEBUG("Protocol: SSL3\n"); + break; + + case SP_PROT_PCT1_CLIENT: + DEBUG("Protocol: PCT\n"); + break; + + case SP_PROT_SSL2_CLIENT: + DEBUG("Protocol: SSL2\n"); + break; + + default: + DEBUG("Protocol: 0x%x\n", ConnectionInfo.dwProtocol); + } + + switch (ConnectionInfo.aiCipher) + { + case CALG_RC4: + DEBUG("Cipher: RC4\n"); + break; + + case CALG_3DES: + DEBUG("Cipher: Triple DES\n"); + break; + + case CALG_RC2: + DEBUG("Cipher: RC2\n"); + break; + + case CALG_DES: + case CALG_CYLINK_MEK: + DEBUG("Cipher: DES\n"); + break; + + case CALG_SKIPJACK: + DEBUG("Cipher: Skipjack\n"); + break; + + default: + DEBUG("Cipher: 0x%x\n", ConnectionInfo.aiCipher); + } + + DEBUG("Cipher strength: %d\n", ConnectionInfo.dwCipherStrength); + + switch (ConnectionInfo.aiHash) + { + case CALG_MD5: + DEBUG("Hash: MD5\n"); + break; + + case CALG_SHA: + DEBUG("Hash: SHA\n"); + break; + + default: + DEBUG("Hash: 0x%x\n", ConnectionInfo.aiHash); + } + + DEBUG("Hash strength: %d\n", ConnectionInfo.dwHashStrength); + + switch (ConnectionInfo.aiExch) + { + case CALG_RSA_KEYX: + case CALG_RSA_SIGN: + DEBUG("Key exchange: RSA\n"); + break; + + case CALG_KEA_KEYX: + DEBUG("Key exchange: KEA\n"); + break; + + case CALG_DH_EPHEM: + DEBUG("Key exchange: DH Ephemeral\n"); + break; + + default: + DEBUG("Key exchange: 0x%x\n", ConnectionInfo.aiExch); + } + + DEBUG("Key exchange strength: %d\n", ConnectionInfo.dwExchStrength); +} + + + + diff --git a/src/hotrod/sys/SChannelSocket.h b/src/hotrod/sys/SChannelSocket.h new file mode 100644 index 00000000..a41e6af9 --- /dev/null +++ b/src/hotrod/sys/SChannelSocket.h @@ -0,0 +1,81 @@ +#ifndef ISPN_HOTROD_SYS_SCHANNELSOCKET_H +#define ISPN_HOTROD_SYS_SCHANNELSOCKET_H + +#define SECURITY_WIN32 + +#include +#include +#include +#include +#include +#include +#include + +namespace infinispan { + namespace hotrod { + namespace sys { + + class SChannelSocket : public infinispan::hotrod::sys::Socket + { + public: + SChannelSocket(const std::string& _serverCAPath, const std::string& _serverCAFile, const std::string& _clientCertificateFile); + virtual ~SChannelSocket(); + // Inherited via Socket + static SChannelSocket* create(const std::string& _serverCAPath, const std::string& _serverCAFile, const std::string& _clientCertificateFile); + + virtual void connect(const std::string & host, int port, int timeout) override; + virtual void close() override; + virtual void setTcpNoDelay(bool tcpNoDelay) override; + virtual void setTimeout(int timeout) override; + virtual size_t read(char *p, size_t n) override; + virtual void write(const char *p, size_t n) override; + virtual int getSocket() override; + + private: + static void displayWinSockError(DWORD ErrCode); + static void displayWinVerifyTrustError(DWORD Status); + static INT connectToServer(std::string host, int iPortNumber, SOCKET * pSocket); + static DWORD verifyServerCertificate(PCCERT_CONTEXT pServerCert, std::string host, DWORD dwCertFlags); + static void logAndThrow(const std::string& host, const int port, const std::string& msg); + + SECURITY_STATUS readDecrypt(const DWORD bufsize, size_t *read_counter); + DWORD encryptSend(size_t len, SecPkgContext_StreamSizes Sizes); + SECURITY_STATUS clientHandshakeLoop(BOOL doInitialRead, SecBuffer *pExtraData); + void getNewClientCredentials(); + SECURITY_STATUS performClientHandshake(std::string host, SecBuffer *pExtraData); + void cleanup(); + void displayConnectionInfo(); + + class SChannelInitializer + { + public: + SChannelInitializer(); + PSecurityFunctionTable g_pSSPI; + }; + + static SChannelInitializer initializer; + static HCERTSTORE hMyCertStore; + + PCCERT_CONTEXT pRemoteCertContext = NULL; + std::string m_serverCAPath; + std::string m_serverCAFile; + std::string m_clientCertificateFile; + SOCKET Client_Socket; + CredHandle hCred; + struct _SecHandle hCtxt; + CtxtHandle hContext; + PBYTE pbRBuffer; + PBYTE pbWBuffer; + PBYTE recvBuff; + DWORD cbRBuffer; + unsigned int ready_bytes=0; + unsigned int offset_bytes=0; + boolean isContextInitialized = false; + boolean isCredsInitialized = false; + boolean onlyVerified = false; + }; + + } + } +} +#endif /* ISPN_HOTROD_SYS_SCHANNELSOCKET_H */ diff --git a/test/SimpleSChannel.cpp b/test/SimpleSChannel.cpp new file mode 100644 index 00000000..7e0b371d --- /dev/null +++ b/test/SimpleSChannel.cpp @@ -0,0 +1,45 @@ +#include "infinispan/hotrod/ConfigurationBuilder.h" +#include "infinispan/hotrod/RemoteCacheManager.h" +#include "infinispan/hotrod/RemoteCache.h" +#include "infinispan/hotrod/Version.h" + +#include "infinispan/hotrod/JBasicMarshaller.h" +#include +#include +#include + +// For CTest: return 0 if all tests pass, non-zero otherwise. + +using namespace infinispan::hotrod; +int main(int argc, char** argv) { + std::cout << "SChannel Test" << std::endl; + if (argc < 2) { + std::cerr << "Usage: " << argv[0] << " ca_path [client_ca_file]" << std::endl; + return 1; + } + ConfigurationBuilder builder; + builder.addServer().host("127.0.0.1").port(11222).protocolVersion(Configuration::PROTOCOL_VERSION_24); + builder.ssl().enable().serverCAFile(argv[1]); + if (argc > 2) { + std::cout << "Using " << argv[1] << " as certificate" << std::endl; + builder.ssl().clientCertificateFile(argv[2]); + } + RemoteCacheManager cacheManager(builder.build(), false); + BasicMarshaller *km = new BasicMarshaller(); + BasicMarshaller *vm = new BasicMarshaller(); + RemoteCache cache = cacheManager.getCache(km, + &Marshaller::destroy, vm, &Marshaller::destroy); + cacheManager.start(); + cache.clear(); + std::string k1("key"); + std::string v1("value"); + cache.put(k1, v1); + std::unique_ptr rv(cache.get(k1)); + if (rv->compare(v1)) { + std::cerr << "get/put fail for " << k1 << " got " << *rv << " expected " << v1 << std::endl; + return 1; + } + cacheManager.stop(); + return 0; +} +