Skip to content

Commit 325da4e

Browse files
committed
net:tcp-server: add a SRW lock around m_writeCallback
1 parent f2b3312 commit 325da4e

File tree

2 files changed

+57
-33
lines changed

2 files changed

+57
-33
lines changed

code/components/net-tcp-server/include/UvTcpServer.h

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <uv.h>
1111

1212
#include <memory>
13+
#include <shared_mutex>
1314

1415
#include "TcpServer.h"
1516

@@ -28,6 +29,8 @@ class UvTcpServerStream : public TcpServerStream
2829

2930
std::unique_ptr<uv_async_t> m_writeCallback;
3031

32+
std::shared_mutex m_writeCallbackMutex;
33+
3134
tbb::concurrent_queue<std::function<void()>> m_pendingRequests;
3235

3336
std::vector<char> m_readBuffer;

code/components/net-tcp-server/src/UvTcpServer.cpp

+54-33
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,17 @@ void UvTcpServerStream::CloseClient()
116116
{
117117
if (m_client.get())
118118
{
119+
decltype(m_writeCallback) writeCallback;
120+
121+
{
122+
std::unique_lock<std::shared_mutex> lock(m_writeCallbackMutex);
123+
writeCallback = std::move(m_writeCallback);
124+
}
125+
119126
uv_read_stop(reinterpret_cast<uv_stream_t*>(m_client.get()));
120127

121128
UvClose(std::move(m_client));
122-
UvClose(std::move(m_writeCallback));
129+
UvClose(std::move(writeCallback));
123130
}
124131
}
125132

@@ -130,10 +137,14 @@ bool UvTcpServerStream::Accept(std::unique_ptr<uv_tcp_t>&& client)
130137
uv_tcp_nodelay(m_client.get(), true);
131138

132139
// initialize a write callback handle
133-
m_writeCallback = std::make_unique<uv_async_t>();
134-
m_writeCallback->data = this;
140+
{
141+
std::unique_lock<std::shared_mutex> lock(m_writeCallbackMutex);
142+
m_writeCallback = std::make_unique<uv_async_t>();
143+
144+
m_writeCallback->data = this;
135145

136-
uv_async_init(m_server->GetManager()->GetLoop(), m_writeCallback.get(), UvCallback<uv_async_t, UvTcpServerStream, &UvTcpServerStream::HandlePendingWrites>);
146+
uv_async_init(m_server->GetManager()->GetLoop(), m_writeCallback.get(), UvCallback<uv_async_t, UvTcpServerStream, &UvTcpServerStream::HandlePendingWrites>);
147+
}
137148

138149
// accept
139150
int result = uv_accept(reinterpret_cast<uv_stream_t*>(m_server->GetServer()),
@@ -204,46 +215,52 @@ void UvTcpServerStream::Write(const std::vector<uint8_t>& data)
204215
fwRefContainer<UvTcpServerStream> stream;
205216
};
206217

207-
if (m_writeCallback)
208218
{
209-
// prepare a write request
210-
UvWriteReq* writeReq = new UvWriteReq;
211-
writeReq->sendData = data;
212-
writeReq->buffer.base = reinterpret_cast<char*>(&writeReq->sendData[0]);
213-
writeReq->buffer.len = writeReq->sendData.size();
214-
writeReq->stream = this;
215-
216-
writeReq->write.data = writeReq;
219+
std::shared_lock<std::shared_mutex> lock(m_writeCallbackMutex);
217220

218-
// submit the write request
219-
m_pendingRequests.push([=]()
221+
if (m_writeCallback)
220222
{
221-
if (!m_client)
222-
{
223-
return;
224-
}
223+
// prepare a write request
224+
UvWriteReq* writeReq = new UvWriteReq;
225+
writeReq->sendData = data;
226+
writeReq->buffer.base = reinterpret_cast<char*>(&writeReq->sendData[0]);
227+
writeReq->buffer.len = writeReq->sendData.size();
228+
writeReq->stream = this;
225229

226-
// send the write request
227-
uv_write(&writeReq->write, reinterpret_cast<uv_stream_t*>(m_client.get()), &writeReq->buffer, 1, [](uv_write_t * write, int status)
228-
{
229-
UvWriteReq* req = reinterpret_cast<UvWriteReq*>(write->data);
230+
writeReq->write.data = writeReq;
230231

231-
if (status < 0)
232+
// submit the write request
233+
m_pendingRequests.push([this, writeReq]()
234+
{
235+
if (!m_client)
232236
{
233-
//trace("write to %s failed - %s\n", req->stream->GetPeerAddress().ToString().c_str(), uv_strerror(status));
237+
return;
234238
}
235239

236-
delete req;
240+
// send the write request
241+
uv_write(&writeReq->write, reinterpret_cast<uv_stream_t*>(m_client.get()), &writeReq->buffer, 1, [](uv_write_t* write, int status)
242+
{
243+
UvWriteReq* req = reinterpret_cast<UvWriteReq*>(write->data);
244+
245+
if (status < 0)
246+
{
247+
//trace("write to %s failed - %s\n", req->stream->GetPeerAddress().ToString().c_str(), uv_strerror(status));
248+
}
249+
250+
delete req;
251+
});
237252
});
238-
});
239253

240-
// wake the callback
241-
uv_async_send(m_writeCallback.get());
254+
// wake the callback
255+
uv_async_send(m_writeCallback.get());
256+
}
242257
}
243258
}
244259

245260
void UvTcpServerStream::ScheduleCallback(const TScheduledCallback& callback)
246261
{
262+
std::shared_lock<std::shared_mutex> lock(m_writeCallbackMutex);
263+
247264
if (m_writeCallback)
248265
{
249266
m_pendingRequests.push(callback);
@@ -302,11 +319,15 @@ void UvTcpServerStream::Close()
302319
});
303320

304321
// wake the callback
305-
auto wc = m_writeCallback.get();
306-
307-
if (wc)
308322
{
309-
uv_async_send(wc);
323+
std::shared_lock<std::shared_mutex> lock(m_writeCallbackMutex);
324+
325+
auto wc = m_writeCallback.get();
326+
327+
if (wc)
328+
{
329+
uv_async_send(wc);
330+
}
310331
}
311332
}
312333
}

0 commit comments

Comments
 (0)