12
12
#include " diagnosticsipc.h"
13
13
#include " processdescriptor.h"
14
14
15
- IpcStream::DiagnosticsIpc::DiagnosticsIpc (const int serverSocket, sockaddr_un *const pServerAddress) :
15
+ #if __GNUC__
16
+ #include < poll.h>
17
+ #else
18
+ #include < sys/poll.h>
19
+ #endif // __GNUC__
20
+
21
+ IpcStream::DiagnosticsIpc::DiagnosticsIpc (const int serverSocket, sockaddr_un *const pServerAddress, ConnectionMode mode) :
22
+ mode(mode),
16
23
_serverSocket(serverSocket),
17
24
_pServerAddress(new sockaddr_un),
18
- _isClosed(false )
25
+ _isClosed(false ),
26
+ _isListening(false )
19
27
{
20
28
_ASSERTE (_pServerAddress != nullptr );
21
- _ASSERTE (_serverSocket != -1 );
22
29
_ASSERTE (pServerAddress != nullptr );
23
30
24
31
if (_pServerAddress == nullptr || pServerAddress == nullptr )
@@ -32,24 +39,8 @@ IpcStream::DiagnosticsIpc::~DiagnosticsIpc()
32
39
delete _pServerAddress;
33
40
}
34
41
35
- IpcStream::DiagnosticsIpc *IpcStream::DiagnosticsIpc::Create (const char *const pIpcName, ErrorCallback callback)
42
+ IpcStream::DiagnosticsIpc *IpcStream::DiagnosticsIpc::Create (const char *const pIpcName, ConnectionMode mode, ErrorCallback callback)
36
43
{
37
- #ifdef __APPLE__
38
- mode_t prev_mask = umask (~(S_IRUSR | S_IWUSR)); // This will set the default permission bit to 600
39
- #endif // __APPLE__
40
-
41
- const int serverSocket = ::socket (AF_UNIX, SOCK_STREAM, 0 );
42
- if (serverSocket == -1 )
43
- {
44
- if (callback != nullptr )
45
- callback (strerror (errno), errno);
46
- #ifdef __APPLE__
47
- umask (prev_mask);
48
- #endif // __APPLE__
49
- _ASSERTE (!" Failed to create diagnostics IPC socket." );
50
- return nullptr ;
51
- }
52
-
53
44
sockaddr_un serverAddress{};
54
45
serverAddress.sun_family = AF_UNIX;
55
46
@@ -71,6 +62,24 @@ IpcStream::DiagnosticsIpc *IpcStream::DiagnosticsIpc::Create(const char *const p
71
62
" socket" );
72
63
}
73
64
65
+ if (mode == ConnectionMode::CLIENT)
66
+ return new IpcStream::DiagnosticsIpc (-1 , &serverAddress, ConnectionMode::CLIENT);
67
+
68
+ #ifdef __APPLE__
69
+ mode_t prev_mask = umask (~(S_IRUSR | S_IWUSR)); // This will set the default permission bit to 600
70
+ #endif // __APPLE__
71
+
72
+ const int serverSocket = ::socket (AF_UNIX, SOCK_STREAM, 0 );
73
+ if (serverSocket == -1 )
74
+ {
75
+ if (callback != nullptr )
76
+ callback (strerror (errno), errno);
77
+ #ifdef __APPLE__
78
+ umask (prev_mask);
79
+ #endif // __APPLE__
80
+ _ASSERTE (!" Failed to create diagnostics IPC socket." );
81
+ return nullptr ;
82
+ }
74
83
75
84
#ifndef __APPLE__
76
85
if (fchmod (serverSocket, S_IRUSR | S_IWUSR) == -1 )
@@ -99,33 +108,52 @@ IpcStream::DiagnosticsIpc *IpcStream::DiagnosticsIpc::Create(const char *const p
99
108
return nullptr ;
100
109
}
101
110
102
- const int fSuccessfulListen = ::listen (serverSocket, /* backlog */ 255 );
111
+ #ifdef __APPLE__
112
+ umask (prev_mask);
113
+ #endif // __APPLE__
114
+
115
+ return new IpcStream::DiagnosticsIpc (serverSocket, &serverAddress, mode);
116
+ }
117
+
118
+ bool IpcStream::DiagnosticsIpc::Listen (ErrorCallback callback)
119
+ {
120
+ _ASSERTE (mode == ConnectionMode::SERVER);
121
+ if (mode != ConnectionMode::SERVER)
122
+ {
123
+ if (callback != nullptr )
124
+ callback (" Cannot call Listen on a client connection" , -1 );
125
+ return false ;
126
+ }
127
+
128
+ if (_isListening)
129
+ return true ;
130
+
131
+ const int fSuccessfulListen = ::listen (_serverSocket, /* backlog */ 255 );
103
132
if (fSuccessfulListen == -1 )
104
133
{
105
134
if (callback != nullptr )
106
135
callback (strerror (errno), errno);
107
136
_ASSERTE (fSuccessfulListen != -1 );
108
137
109
- const int fSuccessUnlink = ::unlink (serverAddress. sun_path );
138
+ const int fSuccessUnlink = ::unlink (_pServerAddress-> sun_path );
110
139
_ASSERTE (fSuccessUnlink != -1 );
111
140
112
- const int fSuccessClose = ::close (serverSocket );
141
+ const int fSuccessClose = ::close (_serverSocket );
113
142
_ASSERTE (fSuccessClose != -1 );
114
- #ifdef __APPLE__
115
- umask (prev_mask);
116
- #endif // __APPLE__
117
- return nullptr ;
143
+ return false ;
144
+ }
145
+ else
146
+ {
147
+ _isListening = true ;
148
+ return true ;
118
149
}
119
-
120
- #ifdef __APPLE__
121
- umask (prev_mask);
122
- #endif // __APPLE__
123
-
124
- return new IpcStream::DiagnosticsIpc (serverSocket, &serverAddress);
125
150
}
126
151
127
- IpcStream *IpcStream::DiagnosticsIpc::Accept (ErrorCallback callback) const
152
+ IpcStream *IpcStream::DiagnosticsIpc::Accept (ErrorCallback callback)
128
153
{
154
+ _ASSERTE (mode == ConnectionMode::SERVER);
155
+ _ASSERTE (_isListening);
156
+
129
157
sockaddr_un from;
130
158
socklen_t fromlen = sizeof (from);
131
159
const int clientSocket = ::accept (_serverSocket, (sockaddr *)&from, &fromlen);
@@ -136,7 +164,114 @@ IpcStream *IpcStream::DiagnosticsIpc::Accept(ErrorCallback callback) const
136
164
return nullptr ;
137
165
}
138
166
139
- return new IpcStream (clientSocket);
167
+ return new IpcStream (clientSocket, mode);
168
+ }
169
+
170
+ IpcStream *IpcStream::DiagnosticsIpc::Connect (ErrorCallback callback)
171
+ {
172
+ _ASSERTE (mode == ConnectionMode::CLIENT);
173
+
174
+ sockaddr_un clientAddress{};
175
+ clientAddress.sun_family = AF_UNIX;
176
+ const int clientSocket = ::socket (AF_UNIX, SOCK_STREAM, 0 );
177
+ if (clientSocket == -1 )
178
+ {
179
+ if (callback != nullptr )
180
+ callback (strerror (errno), errno);
181
+ return nullptr ;
182
+ }
183
+
184
+ // We don't expect this to block since this is a Unix Domain Socket. `connect` may block until the
185
+ // TCP handshake is complete for TCP/IP sockets, but UDS don't use TCP. `connect` will return even if
186
+ // the server hasn't called `accept`.
187
+ if (::connect (clientSocket, (struct sockaddr *)_pServerAddress, sizeof (*_pServerAddress)) < 0 )
188
+ {
189
+ if (callback != nullptr )
190
+ callback (strerror (errno), errno);
191
+ return nullptr ;
192
+ }
193
+
194
+ return new IpcStream (clientSocket, ConnectionMode::CLIENT);
195
+ }
196
+
197
+ int32_t IpcStream::DiagnosticsIpc::Poll (IpcPollHandle *rgIpcPollHandles, uint32_t nHandles, int32_t timeoutMs, ErrorCallback callback)
198
+ {
199
+ // prepare the pollfd structs
200
+ pollfd *pollfds = new pollfd[nHandles];
201
+ for (uint32_t i = 0 ; i < nHandles; i++)
202
+ {
203
+ rgIpcPollHandles[i].revents = 0 ; // ignore any values in revents
204
+ int fd = -1 ;
205
+ if (rgIpcPollHandles[i].pIpc != nullptr )
206
+ {
207
+ // SERVER
208
+ _ASSERTE (rgIpcPollHandles[i].pIpc ->mode == ConnectionMode::SERVER);
209
+ fd = rgIpcPollHandles[i].pIpc ->_serverSocket ;
210
+ }
211
+ else
212
+ {
213
+ // CLIENT
214
+ _ASSERTE (rgIpcPollHandles[i].pStream != nullptr );
215
+ fd = rgIpcPollHandles[i].pStream ->_clientSocket ;
216
+ }
217
+
218
+ pollfds[i].fd = fd;
219
+ pollfds[i].events = POLLIN;
220
+ }
221
+
222
+ int retval = poll (pollfds, nHandles, timeoutMs);
223
+
224
+ // Check results
225
+ if (retval < 0 )
226
+ {
227
+ for (uint32_t i = 0 ; i < nHandles; i++)
228
+ {
229
+ if ((pollfds[i].revents & POLLERR) && callback != nullptr )
230
+ callback (strerror (errno), errno);
231
+ rgIpcPollHandles[i].revents = (uint8_t )PollEvents::ERR;
232
+ }
233
+ delete[] pollfds;
234
+ return -1 ;
235
+ }
236
+ else if (retval == 0 )
237
+ {
238
+ // we timed out
239
+ delete[] pollfds;
240
+ return 0 ;
241
+ }
242
+
243
+ for (uint32_t i = 0 ; i < nHandles; i++)
244
+ {
245
+ if (pollfds[i].revents != 0 )
246
+ {
247
+ // error check FIRST
248
+ if (pollfds[i].revents & POLLHUP)
249
+ {
250
+ // check for hangup first because a closed socket
251
+ // will technically meet the requirements for POLLIN
252
+ // i.e., a call to recv/read won't block
253
+ rgIpcPollHandles[i].revents = (uint8_t )PollEvents::HANGUP;
254
+ delete[] pollfds;
255
+ return -1 ;
256
+ }
257
+ else if ((pollfds[i].revents & (POLLERR|POLLNVAL)))
258
+ {
259
+ if (callback != nullptr )
260
+ callback (" Poll error" , (uint32_t )pollfds[i].revents );
261
+ rgIpcPollHandles[i].revents = (uint8_t )PollEvents::ERR;
262
+ delete[] pollfds;
263
+ return -1 ;
264
+ }
265
+ else if (pollfds[i].revents & POLLIN)
266
+ {
267
+ rgIpcPollHandles[i].revents = (uint8_t )PollEvents::SIGNALED;
268
+ break ;
269
+ }
270
+ }
271
+ }
272
+
273
+ delete[] pollfds;
274
+ return 1 ;
140
275
}
141
276
142
277
void IpcStream::DiagnosticsIpc::Close (ErrorCallback callback)
@@ -172,45 +307,101 @@ void IpcStream::DiagnosticsIpc::Unlink(ErrorCallback callback)
172
307
}
173
308
174
309
IpcStream::~IpcStream ()
310
+ {
311
+ Close ();
312
+ }
313
+
314
+ void IpcStream::Close (ErrorCallback)
175
315
{
176
316
if (_clientSocket != -1 )
177
317
{
178
318
Flush ();
179
319
180
320
const int fSuccessClose = ::close (_clientSocket);
181
321
_ASSERTE (fSuccessClose != -1 );
322
+ _clientSocket = -1 ;
182
323
}
183
324
}
184
325
185
- bool IpcStream::Read (void *lpBuffer, const uint32_t nBytesToRead, uint32_t &nBytesRead) const
326
+ bool IpcStream::Read (void *lpBuffer, const uint32_t nBytesToRead, uint32_t &nBytesRead, const int32_t timeoutMs)
186
327
{
187
328
_ASSERTE (lpBuffer != nullptr );
188
329
189
- const ssize_t ssize = ::recv (_clientSocket, lpBuffer, nBytesToRead, 0 );
190
- const bool fSuccess = ssize != -1 ;
330
+ if (timeoutMs != InfiniteTimeout)
331
+ {
332
+ pollfd pfd;
333
+ pfd.fd = _clientSocket;
334
+ pfd.events = POLLIN;
335
+ int retval = poll (&pfd, 1 , timeoutMs);
336
+ if (retval <= 0 || pfd.revents != POLLIN)
337
+ {
338
+ // timeout or error
339
+ return false ;
340
+ }
341
+ // else fallthrough
342
+ }
343
+
344
+ uint8_t *lpBufferCursor = (uint8_t *)lpBuffer;
345
+ ssize_t currentBytesRead = 0 ;
346
+ ssize_t totalBytesRead = 0 ;
347
+ bool fSuccess = true ;
348
+ while (fSuccess && nBytesToRead - totalBytesRead > 0 )
349
+ {
350
+ currentBytesRead = ::recv (_clientSocket, lpBufferCursor, nBytesToRead - totalBytesRead, 0 );
351
+ fSuccess = currentBytesRead != 0 ;
352
+ if (!fSuccess )
353
+ break ;
354
+ totalBytesRead += currentBytesRead;
355
+ lpBufferCursor += currentBytesRead;
356
+ }
191
357
192
358
if (!fSuccess )
193
359
{
194
360
// TODO: Add error handling.
195
361
}
196
362
197
- nBytesRead = static_cast <uint32_t >(ssize );
363
+ nBytesRead = static_cast <uint32_t >(totalBytesRead );
198
364
return fSuccess ;
199
365
}
200
366
201
- bool IpcStream::Write (const void *lpBuffer, const uint32_t nBytesToWrite, uint32_t &nBytesWritten) const
367
+ bool IpcStream::Write (const void *lpBuffer, const uint32_t nBytesToWrite, uint32_t &nBytesWritten, const int32_t timeoutMs)
202
368
{
203
369
_ASSERTE (lpBuffer != nullptr );
204
370
205
- const ssize_t ssize = ::send (_clientSocket, lpBuffer, nBytesToWrite, 0 );
206
- const bool fSuccess = ssize != -1 ;
371
+ if (timeoutMs != InfiniteTimeout)
372
+ {
373
+ pollfd pfd;
374
+ pfd.fd = _clientSocket;
375
+ pfd.events = POLLOUT;
376
+ int retval = poll (&pfd, 1 , timeoutMs);
377
+ if (retval <= 0 || pfd.revents != POLLOUT)
378
+ {
379
+ // timeout or error
380
+ return false ;
381
+ }
382
+ // else fallthrough
383
+ }
384
+
385
+ uint8_t *lpBufferCursor = (uint8_t *)lpBuffer;
386
+ ssize_t currentBytesWritten = 0 ;
387
+ ssize_t totalBytesWritten = 0 ;
388
+ bool fSuccess = true ;
389
+ while (fSuccess && nBytesToWrite - totalBytesWritten > 0 )
390
+ {
391
+ currentBytesWritten = ::send (_clientSocket, lpBufferCursor, nBytesToWrite - totalBytesWritten, 0 );
392
+ fSuccess = currentBytesWritten != -1 ;
393
+ if (!fSuccess )
394
+ break ;
395
+ lpBufferCursor += currentBytesWritten;
396
+ totalBytesWritten += currentBytesWritten;
397
+ }
207
398
208
399
if (!fSuccess )
209
400
{
210
401
// TODO: Add error handling.
211
402
}
212
403
213
- nBytesWritten = static_cast <uint32_t >(ssize );
404
+ nBytesWritten = static_cast <uint32_t >(totalBytesWritten );
214
405
return fSuccess ;
215
406
}
216
407
0 commit comments