diff --git a/Makefile b/Makefile index 2a049f03f..d67ddc5c8 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,8 @@ test-echo: all kill -9 `cat echo.pid` diff -u test.out Sources/Examples/Echo/test.gold -test-plugin: all +test-plugin: + swift build -v $(CFLAGS) --product protoc-gen-swiftgrpc protoc Sources/Examples/Echo/echo.proto --proto_path=Sources/Examples/Echo --plugin=.build/debug/protoc-gen-swift --plugin=.build/debug/protoc-gen-swiftgrpc --swiftgrpc_out=/tmp --swiftgrpc_opt=TestStubs=true diff -u /tmp/echo.grpc.swift Sources/Examples/Echo/Generated/echo.grpc.swift diff --git a/Sources/Examples/Echo/Generated/echo.grpc.swift b/Sources/Examples/Echo/Generated/echo.grpc.swift index 0534e12aa..f81e02878 100644 --- a/Sources/Examples/Echo/Generated/echo.grpc.swift +++ b/Sources/Examples/Echo/Generated/echo.grpc.swift @@ -32,12 +32,17 @@ fileprivate final class Echo_EchoGetCallBase: ClientCallUnaryBase Echo_EchoResponse? + /// Do not call this directly, call `receive()` in the protocol extension below instead. + func _receive(timeout: DispatchTime) throws -> Echo_EchoResponse? /// Call this to wait for a result. Nonblocking. func receive(completion: @escaping (ResultOrRPCError) -> Void) throws } +internal extension Echo_EchoExpandCall { + /// Call this to wait for a result. Blocking. + func receive(timeout: DispatchTime = .distantFuture) throws -> Echo_EchoResponse? { return try self._receive(timeout: timeout) } +} + fileprivate final class Echo_EchoExpandCallBase: ClientCallServerStreamingBase, Echo_EchoExpandCall { override class var method: String { return "/echo.Echo/Expand" } } @@ -49,8 +54,8 @@ class Echo_EchoExpandCallTestStub: ClientCallServerStreamingTestStub Void) throws - /// Send a message to the stream and wait for the send operation to finish. Blocking. - func send(_ message: Echo_EchoRequest) throws + /// Do not call this directly, call `send()` in the protocol extension below instead. + func _send(_ message: Echo_EchoRequest, timeout: DispatchTime) throws /// Call this to close the connection and wait for a response. Blocking. func closeAndReceive() throws -> Echo_EchoResponse @@ -58,6 +63,11 @@ internal protocol Echo_EchoCollectCall: ClientCallClientStreaming { func closeAndReceive(completion: @escaping (ResultOrRPCError) -> Void) throws } +internal extension Echo_EchoCollectCall { + /// Send a message to the stream and wait for the send operation to finish. Blocking. + func send(_ message: Echo_EchoRequest, timeout: DispatchTime = .distantFuture) throws { try self._send(message, timeout: timeout) } +} + fileprivate final class Echo_EchoCollectCallBase: ClientCallClientStreamingBase, Echo_EchoCollectCall { override class var method: String { return "/echo.Echo/Collect" } } @@ -69,15 +79,15 @@ class Echo_EchoCollectCallTestStub: ClientCallClientStreamingTestStub Echo_EchoResponse? + /// Do not call this directly, call `receive()` in the protocol extension below instead. + func _receive(timeout: DispatchTime) throws -> Echo_EchoResponse? /// Call this to wait for a result. Nonblocking. func receive(completion: @escaping (ResultOrRPCError) -> Void) throws /// Send a message to the stream. Nonblocking. func send(_ message: Echo_EchoRequest, completion: @escaping (Error?) -> Void) throws - /// Send a message to the stream and wait for the send operation to finish. Blocking. - func send(_ message: Echo_EchoRequest) throws + /// Do not call this directly, call `send()` in the protocol extension below instead. + func _send(_ message: Echo_EchoRequest, timeout: DispatchTime) throws /// Call this to close the sending connection. Blocking. func closeSend() throws @@ -85,6 +95,16 @@ internal protocol Echo_EchoUpdateCall: ClientCallBidirectionalStreaming { func closeSend(completion: (() -> Void)?) throws } +internal extension Echo_EchoUpdateCall { + /// Call this to wait for a result. Blocking. + func receive(timeout: DispatchTime = .distantFuture) throws -> Echo_EchoResponse? { return try self._receive(timeout: timeout) } +} + +internal extension Echo_EchoUpdateCall { + /// Send a message to the stream and wait for the send operation to finish. Blocking. + func send(_ message: Echo_EchoRequest, timeout: DispatchTime = .distantFuture) throws { try self._send(message, timeout: timeout) } +} + fileprivate final class Echo_EchoUpdateCallBase: ClientCallBidirectionalStreamingBase, Echo_EchoUpdateCall { override class var method: String { return "/echo.Echo/Update" } } @@ -207,21 +227,26 @@ class Echo_EchoGetSessionTestStub: ServerSessionUnaryTestStub, Echo_EchoGetSessi internal protocol Echo_EchoExpandSession: ServerSessionServerStreaming { /// Send a message to the stream. Nonblocking. func send(_ message: Echo_EchoResponse, completion: @escaping (Error?) -> Void) throws - /// Send a message to the stream and wait for the send operation to finish. Blocking. - func send(_ message: Echo_EchoResponse) throws + /// Do not call this directly, call `send()` in the protocol extension below instead. + func _send(_ message: Echo_EchoResponse, timeout: DispatchTime) throws /// Close the connection and send the status. Non-blocking. /// You MUST call this method once you are done processing the request. func close(withStatus status: ServerStatus, completion: (() -> Void)?) throws } +internal extension Echo_EchoExpandSession { + /// Send a message to the stream and wait for the send operation to finish. Blocking. + func send(_ message: Echo_EchoResponse, timeout: DispatchTime = .distantFuture) throws { try self._send(message, timeout: timeout) } +} + fileprivate final class Echo_EchoExpandSessionBase: ServerSessionServerStreamingBase, Echo_EchoExpandSession {} class Echo_EchoExpandSessionTestStub: ServerSessionServerStreamingTestStub, Echo_EchoExpandSession {} internal protocol Echo_EchoCollectSession: ServerSessionClientStreaming { - /// Call this to wait for a result. Blocking. - func receive() throws -> Echo_EchoRequest? + /// Do not call this directly, call `receive()` in the protocol extension below instead. + func _receive(timeout: DispatchTime) throws -> Echo_EchoRequest? /// Call this to wait for a result. Nonblocking. func receive(completion: @escaping (ResultOrRPCError) -> Void) throws @@ -234,26 +259,41 @@ internal protocol Echo_EchoCollectSession: ServerSessionClientStreaming { func sendErrorAndClose(status: ServerStatus, completion: (() -> Void)?) throws } +internal extension Echo_EchoCollectSession { + /// Call this to wait for a result. Blocking. + func receive(timeout: DispatchTime = .distantFuture) throws -> Echo_EchoRequest? { return try self._receive(timeout: timeout) } +} + fileprivate final class Echo_EchoCollectSessionBase: ServerSessionClientStreamingBase, Echo_EchoCollectSession {} class Echo_EchoCollectSessionTestStub: ServerSessionClientStreamingTestStub, Echo_EchoCollectSession {} internal protocol Echo_EchoUpdateSession: ServerSessionBidirectionalStreaming { - /// Call this to wait for a result. Blocking. - func receive() throws -> Echo_EchoRequest? + /// Do not call this directly, call `receive()` in the protocol extension below instead. + func _receive(timeout: DispatchTime) throws -> Echo_EchoRequest? /// Call this to wait for a result. Nonblocking. func receive(completion: @escaping (ResultOrRPCError) -> Void) throws /// Send a message to the stream. Nonblocking. func send(_ message: Echo_EchoResponse, completion: @escaping (Error?) -> Void) throws - /// Send a message to the stream and wait for the send operation to finish. Blocking. - func send(_ message: Echo_EchoResponse) throws + /// Do not call this directly, call `send()` in the protocol extension below instead. + func _send(_ message: Echo_EchoResponse, timeout: DispatchTime) throws /// Close the connection and send the status. Non-blocking. /// You MUST call this method once you are done processing the request. func close(withStatus status: ServerStatus, completion: (() -> Void)?) throws } +internal extension Echo_EchoUpdateSession { + /// Call this to wait for a result. Blocking. + func receive(timeout: DispatchTime = .distantFuture) throws -> Echo_EchoRequest? { return try self._receive(timeout: timeout) } +} + +internal extension Echo_EchoUpdateSession { + /// Send a message to the stream and wait for the send operation to finish. Blocking. + func send(_ message: Echo_EchoResponse, timeout: DispatchTime = .distantFuture) throws { try self._send(message, timeout: timeout) } +} + fileprivate final class Echo_EchoUpdateSessionBase: ServerSessionBidirectionalStreamingBase, Echo_EchoUpdateSession {} class Echo_EchoUpdateSessionTestStub: ServerSessionBidirectionalStreamingTestStub, Echo_EchoUpdateSession {} diff --git a/Sources/SwiftGRPC/Runtime/ClientCallBidirectionalStreaming.swift b/Sources/SwiftGRPC/Runtime/ClientCallBidirectionalStreaming.swift index 9b2fd3d11..50ee2439f 100644 --- a/Sources/SwiftGRPC/Runtime/ClientCallBidirectionalStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ClientCallBidirectionalStreaming.swift @@ -59,20 +59,20 @@ open class ClientCallBidirectionalStreamingTestStub OutputType? { + open func _receive(timeout: DispatchTime) throws -> OutputType? { defer { if !outputs.isEmpty { outputs.removeFirst() } } return outputs.first } open func receive(completion: @escaping (ResultOrRPCError) -> Void) throws { - completion(.result(try self.receive())) + completion(.result(try self._receive(timeout: .distantFuture))) } open func send(_ message: InputType, completion _: @escaping (Error?) -> Void) throws { inputs.append(message) } - open func send(_ message: InputType) throws { + open func _send(_ message: InputType, timeout: DispatchTime) throws { inputs.append(message) } diff --git a/Sources/SwiftGRPC/Runtime/ClientCallClientStreaming.swift b/Sources/SwiftGRPC/Runtime/ClientCallClientStreaming.swift index db515e108..8a0840696 100644 --- a/Sources/SwiftGRPC/Runtime/ClientCallClientStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ClientCallClientStreaming.swift @@ -76,7 +76,7 @@ open class ClientCallClientStreamingTestStub: ClientCallSer public init() {} - open func receive() throws -> OutputType? { + open func _receive(timeout: DispatchTime) throws -> OutputType? { defer { if !outputs.isEmpty { outputs.removeFirst() } } return outputs.first } open func receive(completion: @escaping (ResultOrRPCError) -> Void) throws { - completion(.result(try self.receive())) + completion(.result(try self._receive(timeout: .distantFuture))) } open func cancel() {} diff --git a/Sources/SwiftGRPC/Runtime/RPCError.swift b/Sources/SwiftGRPC/Runtime/RPCError.swift index 28ebfd42c..bb81ca404 100644 --- a/Sources/SwiftGRPC/Runtime/RPCError.swift +++ b/Sources/SwiftGRPC/Runtime/RPCError.swift @@ -19,13 +19,14 @@ import Foundation /// Type for errors thrown from generated client code. public enum RPCError: Error { case invalidMessageReceived + case timedOut case callError(CallResult) } public extension RPCError { var callResult: CallResult? { switch self { - case .invalidMessageReceived: return nil + case .invalidMessageReceived, .timedOut: return nil case .callError(let callResult): return callResult } } diff --git a/Sources/SwiftGRPC/Runtime/ServerSessionBidirectionalStreaming.swift b/Sources/SwiftGRPC/Runtime/ServerSessionBidirectionalStreaming.swift index 6c655f419..ad758c0c6 100644 --- a/Sources/SwiftGRPC/Runtime/ServerSessionBidirectionalStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ServerSessionBidirectionalStreaming.swift @@ -69,20 +69,20 @@ open class ServerSessionBidirectionalStreamingTestStub InputType? { + open func _receive(timeout: DispatchTime) throws -> InputType? { defer { if !inputs.isEmpty { inputs.removeFirst() } } return inputs.first } open func receive(completion: @escaping (ResultOrRPCError) -> Void) throws { - completion(.result(try self.receive())) + completion(.result(try self._receive(timeout: .distantFuture))) } open func send(_ message: OutputType, completion _: @escaping (Error?) -> Void) throws { outputs.append(message) } - open func send(_ message: OutputType) throws { + open func _send(_ message: OutputType, timeout: DispatchTime) throws { outputs.append(message) } diff --git a/Sources/SwiftGRPC/Runtime/ServerSessionClientStreaming.swift b/Sources/SwiftGRPC/Runtime/ServerSessionClientStreaming.swift index b1a943502..ce2da3c21 100644 --- a/Sources/SwiftGRPC/Runtime/ServerSessionClientStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ServerSessionClientStreaming.swift @@ -75,13 +75,13 @@ open class ServerSessionClientStreamingTestStub InputType? { + open func _receive(timeout: DispatchTime) throws -> InputType? { defer { if !inputs.isEmpty { inputs.removeFirst() } } return inputs.first } open func receive(completion: @escaping (ResultOrRPCError) -> Void) throws { - completion(.result(try self.receive())) + completion(.result(try self._receive(timeout: .distantFuture))) } open func sendAndClose(response: OutputType, status: ServerStatus, completion: (() -> Void)?) throws { diff --git a/Sources/SwiftGRPC/Runtime/ServerSessionServerStreaming.swift b/Sources/SwiftGRPC/Runtime/ServerSessionServerStreaming.swift index a02e3ff92..a71cf3f11 100644 --- a/Sources/SwiftGRPC/Runtime/ServerSessionServerStreaming.swift +++ b/Sources/SwiftGRPC/Runtime/ServerSessionServerStreaming.swift @@ -72,7 +72,7 @@ open class ServerSessionServerStreamingTestStub: ServerSess outputs.append(message) } - open func send(_ message: OutputType) throws { + open func _send(_ message: OutputType, timeout: DispatchTime) throws { outputs.append(message) } diff --git a/Sources/SwiftGRPC/Runtime/StreamReceiving.swift b/Sources/SwiftGRPC/Runtime/StreamReceiving.swift index 15710bf7c..fd732fc1a 100644 --- a/Sources/SwiftGRPC/Runtime/StreamReceiving.swift +++ b/Sources/SwiftGRPC/Runtime/StreamReceiving.swift @@ -43,14 +43,16 @@ extension StreamReceiving { } } - public func receive() throws -> ReceivedType? { + public func _receive(timeout: DispatchTime) throws -> ReceivedType? { var result: ResultOrRPCError? let sem = DispatchSemaphore(value: 0) try receive { result = $0 sem.signal() } - _ = sem.wait() + if sem.wait(timeout: timeout) == .timedOut { + throw RPCError.timedOut + } switch result! { case .result(let response): return response case .error(let error): throw error diff --git a/Sources/SwiftGRPC/Runtime/StreamSending.swift b/Sources/SwiftGRPC/Runtime/StreamSending.swift index e0ca459cc..f1c3037b1 100644 --- a/Sources/SwiftGRPC/Runtime/StreamSending.swift +++ b/Sources/SwiftGRPC/Runtime/StreamSending.swift @@ -29,14 +29,16 @@ extension StreamSending { try call.sendMessage(data: message.serializedData(), completion: completion) } - public func send(_ message: SentType) throws { + public func _send(_ message: SentType, timeout: DispatchTime) throws { var resultError: Error? let sem = DispatchSemaphore(value: 0) try send(message) { resultError = $0 sem.signal() } - _ = sem.wait() + if sem.wait(timeout: timeout) == .timedOut { + throw RPCError.timedOut + } if let resultError = resultError { throw resultError } diff --git a/Sources/protoc-gen-swiftgrpc/Generator-Client.swift b/Sources/protoc-gen-swiftgrpc/Generator-Client.swift index b283fe1cd..b3b24b465 100644 --- a/Sources/protoc-gen-swiftgrpc/Generator-Client.swift +++ b/Sources/protoc-gen-swiftgrpc/Generator-Client.swift @@ -60,6 +60,8 @@ extension Generator { outdent() println("}") println() + printStreamReceiveExtension(extendedType: callName, receivedType: methodOutputName) + println() println("fileprivate final class \(callName)Base: ClientCallServerStreamingBase<\(methodInputName), \(methodOutputName)>, \(callName) {") indent() println("override class var method: String { return \(methodPath) }") @@ -88,6 +90,8 @@ extension Generator { outdent() println("}") println() + printStreamSendExtension(extendedType: callName, sentType: methodInputName) + println() println("fileprivate final class \(callName)Base: ClientCallClientStreamingBase<\(methodInputName), \(methodOutputName)>, \(callName) {") indent() println("override class var method: String { return \(methodPath) }") @@ -120,6 +124,10 @@ extension Generator { outdent() println("}") println() + printStreamReceiveExtension(extendedType: callName, receivedType: methodOutputName) + println() + printStreamSendExtension(extendedType: callName, sentType: methodInputName) + println() println("fileprivate final class \(callName)Base: ClientCallBidirectionalStreamingBase<\(methodInputName), \(methodOutputName)>, \(callName) {") indent() println("override class var method: String { return \(methodPath) }") diff --git a/Sources/protoc-gen-swiftgrpc/Generator-Methods.swift b/Sources/protoc-gen-swiftgrpc/Generator-Methods.swift index f32f66157..7a1ad32de 100644 --- a/Sources/protoc-gen-swiftgrpc/Generator-Methods.swift +++ b/Sources/protoc-gen-swiftgrpc/Generator-Methods.swift @@ -19,16 +19,34 @@ import SwiftProtobufPluginLibrary extension Generator { func printStreamReceiveMethods(receivedType: String) { - println("/// Call this to wait for a result. Blocking.") - println("func receive() throws -> \(receivedType)?") + println("/// Do not call this directly, call `receive()` in the protocol extension below instead.") + println("func _receive(timeout: DispatchTime) throws -> \(receivedType)?") println("/// Call this to wait for a result. Nonblocking.") println("func receive(completion: @escaping (ResultOrRPCError<\(receivedType)?>) -> Void) throws") } + func printStreamReceiveExtension(extendedType: String, receivedType: String) { + println("\(access) extension \(extendedType) {") + indent() + println("/// Call this to wait for a result. Blocking.") + println("func receive(timeout: DispatchTime = .distantFuture) throws -> \(receivedType)? { return try self._receive(timeout: timeout) }") + outdent() + println("}") + } + func printStreamSendMethods(sentType: String) { println("/// Send a message to the stream. Nonblocking.") println("func send(_ message: \(sentType), completion: @escaping (Error?) -> Void) throws") + println("/// Do not call this directly, call `send()` in the protocol extension below instead.") + println("func _send(_ message: \(sentType), timeout: DispatchTime) throws") + } + + func printStreamSendExtension(extendedType: String,sentType: String) { + println("\(access) extension \(extendedType) {") + indent() println("/// Send a message to the stream and wait for the send operation to finish. Blocking.") - println("func send(_ message: \(sentType)) throws") + println("func send(_ message: \(sentType), timeout: DispatchTime = .distantFuture) throws { try self._send(message, timeout: timeout) }") + outdent() + println("}") } } diff --git a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift index 34171ef41..76655a257 100644 --- a/Sources/protoc-gen-swiftgrpc/Generator-Server.swift +++ b/Sources/protoc-gen-swiftgrpc/Generator-Server.swift @@ -161,6 +161,8 @@ extension Generator { outdent() println("}") println() + printStreamReceiveExtension(extendedType: methodSessionName, receivedType: methodInputName) + println() println("fileprivate final class \(methodSessionName)Base: ServerSessionClientStreamingBase<\(methodInputName), \(methodOutputName)>, \(methodSessionName) {}") if options.generateTestStubs { println() @@ -183,6 +185,8 @@ extension Generator { outdent() println("}") println() + printStreamSendExtension(extendedType: methodSessionName, sentType: methodOutputName) + println() println("fileprivate final class \(methodSessionName)Base: ServerSessionServerStreamingBase<\(methodInputName), \(methodOutputName)>, \(methodSessionName) {}") if options.generateTestStubs { println() @@ -201,6 +205,10 @@ extension Generator { outdent() println("}") println() + printStreamReceiveExtension(extendedType: methodSessionName, receivedType: methodInputName) + println() + printStreamSendExtension(extendedType: methodSessionName, sentType: methodOutputName) + println() println("fileprivate final class \(methodSessionName)Base: ServerSessionBidirectionalStreamingBase<\(methodInputName), \(methodOutputName)>, \(methodSessionName) {}") if options.generateTestStubs { println() diff --git a/Tests/SwiftGRPCTests/ClientTimeoutTests.swift b/Tests/SwiftGRPCTests/ClientTimeoutTests.swift index 2ea98109d..8137752eb 100644 --- a/Tests/SwiftGRPCTests/ClientTimeoutTests.swift +++ b/Tests/SwiftGRPCTests/ClientTimeoutTests.swift @@ -79,6 +79,29 @@ extension ClientTimeoutTests { waitForExpectations(timeout: defaultTimeout) } + func testBidirectionalStreamingTimeoutPassedToReceiveMethod() { + let completionHandlerExpectation = expectation(description: "final completion handler called") + let call = try! client.update { callResult in + XCTAssertEqual(.ok, callResult.statusCode) + completionHandlerExpectation.fulfill() + } + + do { + let result = try call.receive(timeout: .now() + .milliseconds(10)) + XCTFail("should have thrown, received \(String(describing: result)) instead") + } catch let receiveError { + if case .timedOut = receiveError as! RPCError { + // This is the expected case - we need to formulate this as an if statement to use case-based pattern matching. + } else { + XCTFail("received error \(receiveError) instead of .timedOut") + } + } + + try! call.closeSend() + + waitForExpectations(timeout: defaultTimeout) + } + // FIXME(danielalm): Add support for setting a maximum timeout on the server, to prevent DoS attacks where clients // start a ton of calls, but never finish them (i.e. essentially leaking a connection on the server side). }