Skip to content

Commit

Permalink
Add Reverse Diagnostics Server (#33307)
Browse files Browse the repository at this point in the history
* Add Advertise IPC Command
* Update diagnostics server to use both reverse and traditional modes
* Change DOTNET_DiagnosticsServerAddress to DOTNET_DiagnosticsMonitorAddress and only use for reverse connection
* Add IpcStreamFactory abstraction
* IpcStreamFactory::Poll is now more similar to the poll API from Linux
* IpcPollHandle struct is used to abstract listening for client and server connections
* use overlapped io for all io on windows
* Add ConnectionState abstraction
* Implement timeout read/write
  • Loading branch information
John Salem committed Apr 20, 2020
1 parent 779588a commit 629dba5
Show file tree
Hide file tree
Showing 15 changed files with 1,587 additions and 96 deletions.
277 changes: 234 additions & 43 deletions src/coreclr/src/debug/debug-pal/unix/diagnosticsipc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,20 @@
#include "diagnosticsipc.h"
#include "processdescriptor.h"

IpcStream::DiagnosticsIpc::DiagnosticsIpc(const int serverSocket, sockaddr_un *const pServerAddress) :
#if __GNUC__
#include <poll.h>
#else
#include <sys/poll.h>
#endif // __GNUC__

IpcStream::DiagnosticsIpc::DiagnosticsIpc(const int serverSocket, sockaddr_un *const pServerAddress, ConnectionMode mode) :
mode(mode),
_serverSocket(serverSocket),
_pServerAddress(new sockaddr_un),
_isClosed(false)
_isClosed(false),
_isListening(false)
{
_ASSERTE(_pServerAddress != nullptr);
_ASSERTE(_serverSocket != -1);
_ASSERTE(pServerAddress != nullptr);

if (_pServerAddress == nullptr || pServerAddress == nullptr)
Expand All @@ -32,24 +39,8 @@ IpcStream::DiagnosticsIpc::~DiagnosticsIpc()
delete _pServerAddress;
}

IpcStream::DiagnosticsIpc *IpcStream::DiagnosticsIpc::Create(const char *const pIpcName, ErrorCallback callback)
IpcStream::DiagnosticsIpc *IpcStream::DiagnosticsIpc::Create(const char *const pIpcName, ConnectionMode mode, ErrorCallback callback)
{
#ifdef __APPLE__
mode_t prev_mask = umask(~(S_IRUSR | S_IWUSR)); // This will set the default permission bit to 600
#endif // __APPLE__

const int serverSocket = ::socket(AF_UNIX, SOCK_STREAM, 0);
if (serverSocket == -1)
{
if (callback != nullptr)
callback(strerror(errno), errno);
#ifdef __APPLE__
umask(prev_mask);
#endif // __APPLE__
_ASSERTE(!"Failed to create diagnostics IPC socket.");
return nullptr;
}

sockaddr_un serverAddress{};
serverAddress.sun_family = AF_UNIX;

Expand All @@ -71,6 +62,24 @@ IpcStream::DiagnosticsIpc *IpcStream::DiagnosticsIpc::Create(const char *const p
"socket");
}

if (mode == ConnectionMode::CLIENT)
return new IpcStream::DiagnosticsIpc(-1, &serverAddress, ConnectionMode::CLIENT);

#ifdef __APPLE__
mode_t prev_mask = umask(~(S_IRUSR | S_IWUSR)); // This will set the default permission bit to 600
#endif // __APPLE__

const int serverSocket = ::socket(AF_UNIX, SOCK_STREAM, 0);
if (serverSocket == -1)
{
if (callback != nullptr)
callback(strerror(errno), errno);
#ifdef __APPLE__
umask(prev_mask);
#endif // __APPLE__
_ASSERTE(!"Failed to create diagnostics IPC socket.");
return nullptr;
}

#ifndef __APPLE__
if (fchmod(serverSocket, S_IRUSR | S_IWUSR) == -1)
Expand Down Expand Up @@ -99,33 +108,52 @@ IpcStream::DiagnosticsIpc *IpcStream::DiagnosticsIpc::Create(const char *const p
return nullptr;
}

const int fSuccessfulListen = ::listen(serverSocket, /* backlog */ 255);
#ifdef __APPLE__
umask(prev_mask);
#endif // __APPLE__

return new IpcStream::DiagnosticsIpc(serverSocket, &serverAddress, mode);
}

bool IpcStream::DiagnosticsIpc::Listen(ErrorCallback callback)
{
_ASSERTE(mode == ConnectionMode::SERVER);
if (mode != ConnectionMode::SERVER)
{
if (callback != nullptr)
callback("Cannot call Listen on a client connection", -1);
return false;
}

if (_isListening)
return true;

const int fSuccessfulListen = ::listen(_serverSocket, /* backlog */ 255);
if (fSuccessfulListen == -1)
{
if (callback != nullptr)
callback(strerror(errno), errno);
_ASSERTE(fSuccessfulListen != -1);

const int fSuccessUnlink = ::unlink(serverAddress.sun_path);
const int fSuccessUnlink = ::unlink(_pServerAddress->sun_path);
_ASSERTE(fSuccessUnlink != -1);

const int fSuccessClose = ::close(serverSocket);
const int fSuccessClose = ::close(_serverSocket);
_ASSERTE(fSuccessClose != -1);
#ifdef __APPLE__
umask(prev_mask);
#endif // __APPLE__
return nullptr;
return false;
}
else
{
_isListening = true;
return true;
}

#ifdef __APPLE__
umask(prev_mask);
#endif // __APPLE__

return new IpcStream::DiagnosticsIpc(serverSocket, &serverAddress);
}

IpcStream *IpcStream::DiagnosticsIpc::Accept(ErrorCallback callback) const
IpcStream *IpcStream::DiagnosticsIpc::Accept(ErrorCallback callback)
{
_ASSERTE(mode == ConnectionMode::SERVER);
_ASSERTE(_isListening);

sockaddr_un from;
socklen_t fromlen = sizeof(from);
const int clientSocket = ::accept(_serverSocket, (sockaddr *)&from, &fromlen);
Expand All @@ -136,7 +164,114 @@ IpcStream *IpcStream::DiagnosticsIpc::Accept(ErrorCallback callback) const
return nullptr;
}

return new IpcStream(clientSocket);
return new IpcStream(clientSocket, mode);
}

IpcStream *IpcStream::DiagnosticsIpc::Connect(ErrorCallback callback)
{
_ASSERTE(mode == ConnectionMode::CLIENT);

sockaddr_un clientAddress{};
clientAddress.sun_family = AF_UNIX;
const int clientSocket = ::socket(AF_UNIX, SOCK_STREAM, 0);
if (clientSocket == -1)
{
if (callback != nullptr)
callback(strerror(errno), errno);
return nullptr;
}

// We don't expect this to block since this is a Unix Domain Socket. `connect` may block until the
// TCP handshake is complete for TCP/IP sockets, but UDS don't use TCP. `connect` will return even if
// the server hasn't called `accept`.
if (::connect(clientSocket, (struct sockaddr *)_pServerAddress, sizeof(*_pServerAddress)) < 0)
{
if (callback != nullptr)
callback(strerror(errno), errno);
return nullptr;
}

return new IpcStream(clientSocket, ConnectionMode::CLIENT);
}

int32_t IpcStream::DiagnosticsIpc::Poll(IpcPollHandle *rgIpcPollHandles, uint32_t nHandles, int32_t timeoutMs, ErrorCallback callback)
{
// prepare the pollfd structs
pollfd *pollfds = new pollfd[nHandles];
for (uint32_t i = 0; i < nHandles; i++)
{
rgIpcPollHandles[i].revents = 0; // ignore any values in revents
int fd = -1;
if (rgIpcPollHandles[i].pIpc != nullptr)
{
// SERVER
_ASSERTE(rgIpcPollHandles[i].pIpc->mode == ConnectionMode::SERVER);
fd = rgIpcPollHandles[i].pIpc->_serverSocket;
}
else
{
// CLIENT
_ASSERTE(rgIpcPollHandles[i].pStream != nullptr);
fd = rgIpcPollHandles[i].pStream->_clientSocket;
}

pollfds[i].fd = fd;
pollfds[i].events = POLLIN;
}

int retval = poll(pollfds, nHandles, timeoutMs);

// Check results
if (retval < 0)
{
for (uint32_t i = 0; i < nHandles; i++)
{
if ((pollfds[i].revents & POLLERR) && callback != nullptr)
callback(strerror(errno), errno);
rgIpcPollHandles[i].revents = (uint8_t)PollEvents::ERR;
}
delete[] pollfds;
return -1;
}
else if (retval == 0)
{
// we timed out
delete[] pollfds;
return 0;
}

for (uint32_t i = 0; i < nHandles; i++)
{
if (pollfds[i].revents != 0)
{
// error check FIRST
if (pollfds[i].revents & POLLHUP)
{
// check for hangup first because a closed socket
// will technically meet the requirements for POLLIN
// i.e., a call to recv/read won't block
rgIpcPollHandles[i].revents = (uint8_t)PollEvents::HANGUP;
delete[] pollfds;
return -1;
}
else if ((pollfds[i].revents & (POLLERR|POLLNVAL)))
{
if (callback != nullptr)
callback("Poll error", (uint32_t)pollfds[i].revents);
rgIpcPollHandles[i].revents = (uint8_t)PollEvents::ERR;
delete[] pollfds;
return -1;
}
else if (pollfds[i].revents & POLLIN)
{
rgIpcPollHandles[i].revents = (uint8_t)PollEvents::SIGNALED;
break;
}
}
}

delete[] pollfds;
return 1;
}

void IpcStream::DiagnosticsIpc::Close(ErrorCallback callback)
Expand Down Expand Up @@ -172,45 +307,101 @@ void IpcStream::DiagnosticsIpc::Unlink(ErrorCallback callback)
}

IpcStream::~IpcStream()
{
Close();
}

void IpcStream::Close(ErrorCallback)
{
if (_clientSocket != -1)
{
Flush();

const int fSuccessClose = ::close(_clientSocket);
_ASSERTE(fSuccessClose != -1);
_clientSocket = -1;
}
}

bool IpcStream::Read(void *lpBuffer, const uint32_t nBytesToRead, uint32_t &nBytesRead) const
bool IpcStream::Read(void *lpBuffer, const uint32_t nBytesToRead, uint32_t &nBytesRead, const int32_t timeoutMs)
{
_ASSERTE(lpBuffer != nullptr);

const ssize_t ssize = ::recv(_clientSocket, lpBuffer, nBytesToRead, 0);
const bool fSuccess = ssize != -1;
if (timeoutMs != InfiniteTimeout)
{
pollfd pfd;
pfd.fd = _clientSocket;
pfd.events = POLLIN;
int retval = poll(&pfd, 1, timeoutMs);
if (retval <= 0 || pfd.revents != POLLIN)
{
// timeout or error
return false;
}
// else fallthrough
}

uint8_t *lpBufferCursor = (uint8_t*)lpBuffer;
ssize_t currentBytesRead = 0;
ssize_t totalBytesRead = 0;
bool fSuccess = true;
while (fSuccess && nBytesToRead - totalBytesRead > 0)
{
currentBytesRead = ::recv(_clientSocket, lpBufferCursor, nBytesToRead - totalBytesRead, 0);
fSuccess = currentBytesRead != 0;
if (!fSuccess)
break;
totalBytesRead += currentBytesRead;
lpBufferCursor += currentBytesRead;
}

if (!fSuccess)
{
// TODO: Add error handling.
}

nBytesRead = static_cast<uint32_t>(ssize);
nBytesRead = static_cast<uint32_t>(totalBytesRead);
return fSuccess;
}

bool IpcStream::Write(const void *lpBuffer, const uint32_t nBytesToWrite, uint32_t &nBytesWritten) const
bool IpcStream::Write(const void *lpBuffer, const uint32_t nBytesToWrite, uint32_t &nBytesWritten, const int32_t timeoutMs)
{
_ASSERTE(lpBuffer != nullptr);

const ssize_t ssize = ::send(_clientSocket, lpBuffer, nBytesToWrite, 0);
const bool fSuccess = ssize != -1;
if (timeoutMs != InfiniteTimeout)
{
pollfd pfd;
pfd.fd = _clientSocket;
pfd.events = POLLOUT;
int retval = poll(&pfd, 1, timeoutMs);
if (retval <= 0 || pfd.revents != POLLOUT)
{
// timeout or error
return false;
}
// else fallthrough
}

uint8_t *lpBufferCursor = (uint8_t*)lpBuffer;
ssize_t currentBytesWritten = 0;
ssize_t totalBytesWritten = 0;
bool fSuccess = true;
while (fSuccess && nBytesToWrite - totalBytesWritten > 0)
{
currentBytesWritten = ::send(_clientSocket, lpBufferCursor, nBytesToWrite - totalBytesWritten, 0);
fSuccess = currentBytesWritten != -1;
if (!fSuccess)
break;
lpBufferCursor += currentBytesWritten;
totalBytesWritten += currentBytesWritten;
}

if (!fSuccess)
{
// TODO: Add error handling.
}

nBytesWritten = static_cast<uint32_t>(ssize);
nBytesWritten = static_cast<uint32_t>(totalBytesWritten);
return fSuccess;
}

Expand Down
Loading

0 comments on commit 629dba5

Please sign in to comment.