diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h index 0deb7e35fab6c..ce742be7a941c 100644 --- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h +++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h @@ -158,6 +158,7 @@ class MessageHandler { template OutgoingNotification outgoingNotification(llvm::StringLiteral method) { return [&, method](const T ¶ms) { + std::lock_guard transportLock(transportOutputMutex); Logger::info("--> {0}", method); transport.notify(method, llvm::json::Value(params)); }; @@ -172,6 +173,9 @@ class MessageHandler { methodHandlers; JSONTransport &transport; + + /// Mutex to guard sending output messages to the transport. + std::mutex transportOutputMutex; }; } // namespace lsp diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp index 92171f1a053ae..3915146f6a66d 100644 --- a/mlir/lib/Tools/lsp-server-support/Transport.cpp +++ b/mlir/lib/Tools/lsp-server-support/Transport.cpp @@ -30,8 +30,8 @@ namespace { /// - if there were multiple replies, only the first is sent class Reply { public: - Reply(const llvm::json::Value &id, StringRef method, - JSONTransport &transport); + Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport, + std::mutex &transportOutputMutex); Reply(Reply &&other); Reply &operator=(Reply &&) = delete; Reply(const Reply &) = delete; @@ -44,16 +44,19 @@ class Reply { std::atomic replied = {false}; llvm::json::Value id; JSONTransport *transport; + std::mutex &transportOutputMutex; }; } // namespace Reply::Reply(const llvm::json::Value &id, llvm::StringRef method, - JSONTransport &transport) - : id(id), transport(&transport) {} + JSONTransport &transport, std::mutex &transportOutputMutex) + : id(id), transport(&transport), + transportOutputMutex(transportOutputMutex) {} Reply::Reply(Reply &&other) : replied(other.replied.load()), id(std::move(other.id)), - transport(other.transport) { + transport(other.transport), + transportOutputMutex(other.transportOutputMutex) { other.transport = nullptr; } @@ -65,6 +68,7 @@ void Reply::operator()(llvm::Expected reply) { } assert(transport && "expected valid transport to reply to"); + std::lock_guard transportLock(transportOutputMutex); if (reply) { Logger::info("--> reply:{0}({1})", method, id); transport->reply(std::move(id), std::move(reply)); @@ -98,7 +102,7 @@ bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params, llvm::json::Value id) { Logger::info("--> {0}({1})", method, id); - Reply reply(id, method, transport); + Reply reply(id, method, transport, transportOutputMutex); auto it = methodHandlers.find(method); if (it != methodHandlers.end()) {