diff --git a/Sources/GRPCCodeGen/Internal/Renderer/TextBasedRenderer.swift b/Sources/GRPCCodeGen/Internal/Renderer/TextBasedRenderer.swift index 85a2cc264..54e642545 100644 --- a/Sources/GRPCCodeGen/Internal/Renderer/TextBasedRenderer.swift +++ b/Sources/GRPCCodeGen/Internal/Renderer/TextBasedRenderer.swift @@ -202,11 +202,11 @@ struct TextBasedRenderer: RendererProtocol { } /// Renders the specified identifier. - func renderedIdentifier(_ identifier: IdentifierDescription) -> String { + func renderIdentifier(_ identifier: IdentifierDescription) { switch identifier { - case .pattern(let string): return string + case .pattern(let string): writer.writeLine(string) case .type(let existingTypeDescription): - return renderedExistingTypeDescription(existingTypeDescription) + renderExistingTypeDescription(existingTypeDescription) } } @@ -446,7 +446,7 @@ struct TextBasedRenderer: RendererProtocol { switch expression { case .literal(let literalDescription): renderLiteral(literalDescription) case .identifier(let identifierDescription): - writer.writeLine(renderedIdentifier(identifierDescription)) + renderIdentifier(identifierDescription) case .memberAccess(let memberAccessDescription): renderMemberAccess(memberAccessDescription) case .functionCall(let functionCallDescription): renderFunctionCall(functionCallDescription) case .assignment(let assignment): renderAssignment(assignment) @@ -534,20 +534,44 @@ struct TextBasedRenderer: RendererProtocol { } /// Renders the specified type reference to an existing type. - func renderedExistingTypeDescription(_ type: ExistingTypeDescription) -> String { + func renderExistingTypeDescription(_ type: ExistingTypeDescription) { switch type { case .any(let existingTypeDescription): - return "any \(renderedExistingTypeDescription(existingTypeDescription))" + writer.writeLine("any ") + writer.nextLineAppendsToLastLine() + renderExistingTypeDescription(existingTypeDescription) case .generic(let wrapper, let wrapped): - return - "\(renderedExistingTypeDescription(wrapper))<\(renderedExistingTypeDescription(wrapped))>" + renderExistingTypeDescription(wrapper) + writer.nextLineAppendsToLastLine() + writer.writeLine("<") + writer.nextLineAppendsToLastLine() + renderExistingTypeDescription(wrapped) + writer.nextLineAppendsToLastLine() + writer.writeLine(">") case .optional(let existingTypeDescription): - return "\(renderedExistingTypeDescription(existingTypeDescription))?" - case .member(let components): return components.joined(separator: ".") + renderExistingTypeDescription(existingTypeDescription) + writer.nextLineAppendsToLastLine() + writer.writeLine("?") + case .member(let components): + writer.writeLine(components.joined(separator: ".")) case .array(let existingTypeDescription): - return "[\(renderedExistingTypeDescription(existingTypeDescription))]" + writer.writeLine("[") + writer.nextLineAppendsToLastLine() + renderExistingTypeDescription(existingTypeDescription) + writer.nextLineAppendsToLastLine() + writer.writeLine("]") case .dictionaryValue(let existingTypeDescription): - return "[String: \(renderedExistingTypeDescription(existingTypeDescription))]" + writer.writeLine("[String: ") + writer.nextLineAppendsToLastLine() + renderExistingTypeDescription(existingTypeDescription) + writer.nextLineAppendsToLastLine() + writer.writeLine("]") + case .some(let existingTypeDescription): + writer.writeLine("some ") + writer.nextLineAppendsToLastLine() + renderExistingTypeDescription(existingTypeDescription) + case .closure(let closureSignatureDescription): + renderClosureSignature(closureSignatureDescription) } } @@ -558,9 +582,11 @@ struct TextBasedRenderer: RendererProtocol { words.append(renderedAccessModifier(accessModifier)) } words.append(contentsOf: [ - "typealias", alias.name, "=", renderedExistingTypeDescription(alias.existingType), + "typealias", alias.name, "=", ]) - writer.writeLine(words.joinedWords()) + writer.writeLine(words.joinedWords() + " ") + writer.nextLineAppendsToLastLine() + renderExistingTypeDescription(alias.existingType) } /// Renders the specified binding kind. @@ -587,7 +613,9 @@ struct TextBasedRenderer: RendererProtocol { renderExpression(variable.left) if let type = variable.type { writer.nextLineAppendsToLastLine() - writer.writeLine(": \(renderedExistingTypeDescription(type))") + writer.writeLine(": ") + writer.nextLineAppendsToLastLine() + renderExistingTypeDescription(type) } } @@ -703,11 +731,12 @@ struct TextBasedRenderer: RendererProtocol { } /// Renders the specified enum case associated value. - func renderedEnumCaseAssociatedValue(_ value: EnumCaseAssociatedValueDescription) -> String { + func renderEnumCaseAssociatedValue(_ value: EnumCaseAssociatedValueDescription) { var words: [String] = [] if let label = value.label { words.append(label + ":") } - words.append(renderedExistingTypeDescription(value.type)) - return words.joinedWords() + writer.writeLine(words.joinedWords()) + writer.nextLineAppendsToLastLine() + renderExistingTypeDescription(value.type) } /// Renders the specified enum case declaration. @@ -722,9 +751,13 @@ struct TextBasedRenderer: RendererProtocol { renderLiteral(rawValue) case .nameWithAssociatedValues(let values): if values.isEmpty { break } - let associatedValues = values.map(renderedEnumCaseAssociatedValue).joined(separator: ", ") - writer.nextLineAppendsToLastLine() - writer.writeLine("(\(associatedValues))") + for (value, isLast) in values.enumeratedWithLastMarker() { + renderEnumCaseAssociatedValue(value) + if !isLast { + writer.nextLineAppendsToLastLine() + writer.writeLine(", ") + } + } } } @@ -750,7 +783,9 @@ struct TextBasedRenderer: RendererProtocol { func renderedFunctionKind(_ functionKind: FunctionKind) -> String { switch functionKind { case .initializer(let isFailable): return "init\(isFailable ? "?" : "")" - case .function(let name, let isStatic): return (isStatic ? "static " : "") + "func \(name)" + case .function(let name, let isStatic): + return (isStatic ? "static " : "") + "func \(name)" + } } @@ -759,6 +794,54 @@ struct TextBasedRenderer: RendererProtocol { switch keyword { case .throws: return "throws" case .async: return "async" + case .rethrows: return "rethrows" + } + } + + /// Renders the specified function signature. + func renderClosureSignature(_ signature: ClosureSignatureDescription) { + if signature.sendable { + writer.writeLine("@Sendable ") + writer.nextLineAppendsToLastLine() + } + if signature.escaping { + writer.writeLine("@escaping ") + writer.nextLineAppendsToLastLine() + } + + writer.writeLine("(") + let parameters = signature.parameters + let separateLines = parameters.count > 1 + if separateLines { + writer.withNestedLevel { + for (parameter, isLast) in signature.parameters.enumeratedWithLastMarker() { + renderClosureParameter(parameter) + if !isLast { + writer.nextLineAppendsToLastLine() + writer.writeLine(",") + } + } + } + } else { + writer.nextLineAppendsToLastLine() + if let parameter = parameters.first { + renderClosureParameter(parameter) + writer.nextLineAppendsToLastLine() + } + } + writer.writeLine(")") + + let keywords = signature.keywords + for keyword in keywords { + writer.nextLineAppendsToLastLine() + writer.writeLine(" " + renderedFunctionKeyword(keyword)) + } + + if let returnType = signature.returnType { + writer.nextLineAppendsToLastLine() + writer.writeLine(" -> ") + writer.nextLineAppendsToLastLine() + renderExpression(returnType) } } @@ -769,7 +852,26 @@ struct TextBasedRenderer: RendererProtocol { writer.writeLine(renderedAccessModifier(accessModifier) + " ") writer.nextLineAppendsToLastLine() } - writer.writeLine(renderedFunctionKind(signature.kind) + "(") + let generics = signature.generics + writer.writeLine( + renderedFunctionKind(signature.kind) + ) + if !generics.isEmpty { + writer.nextLineAppendsToLastLine() + writer.writeLine("<") + for (genericType, isLast) in generics.enumeratedWithLastMarker() { + writer.nextLineAppendsToLastLine() + renderExistingTypeDescription(genericType) + if !isLast { + writer.nextLineAppendsToLastLine() + writer.writeLine(", ") + } + } + writer.nextLineAppendsToLastLine() + writer.writeLine(">") + } + writer.nextLineAppendsToLastLine() + writer.writeLine("(") let parameters = signature.parameters let separateLines = parameters.count > 1 if separateLines { @@ -806,6 +908,11 @@ struct TextBasedRenderer: RendererProtocol { writer.nextLineAppendsToLastLine() renderExpression(returnType) } + + if let whereClause = signature.whereClause { + writer.nextLineAppendsToLastLine() + writer.writeLine(" " + renderedWhereClause(whereClause)) + } } /// Renders the specified function declaration. @@ -839,11 +946,54 @@ struct TextBasedRenderer: RendererProtocol { } writer.writeLine(": ") writer.nextLineAppendsToLastLine() + + if parameterDescription.inout { + writer.writeLine("inout ") + writer.nextLineAppendsToLastLine() + } + + if let type = parameterDescription.type { + renderExistingTypeDescription(type) + } + + if let defaultValue = parameterDescription.defaultValue { + writer.nextLineAppendsToLastLine() + writer.writeLine(" = ") + writer.nextLineAppendsToLastLine() + renderExpression(defaultValue) + } + } + + /// Renders the specified parameter declaration for a closure. + func renderClosureParameter(_ parameterDescription: ParameterDescription) { + let name = parameterDescription.name + let label: String + if let declaredLabel = parameterDescription.label { + label = declaredLabel + } else { + label = "_" + } + + if let name = name { + writer.writeLine(label) + if name != parameterDescription.label { + // If the label and name are the same value, don't repeat it. + writer.writeLine(" ") + writer.nextLineAppendsToLastLine() + writer.writeLine(name) + writer.nextLineAppendsToLastLine() + } + } + if parameterDescription.inout { writer.writeLine("inout ") writer.nextLineAppendsToLastLine() } - writer.writeLine(renderedExistingTypeDescription(parameterDescription.type)) + + if let type = parameterDescription.type { + renderExistingTypeDescription(type) + } + if let defaultValue = parameterDescription.defaultValue { writer.nextLineAppendsToLastLine() writer.writeLine(" = ") diff --git a/Sources/GRPCCodeGen/Internal/StructuredSwiftRepresentation.swift b/Sources/GRPCCodeGen/Internal/StructuredSwiftRepresentation.swift index fc7ce726e..556c3d909 100644 --- a/Sources/GRPCCodeGen/Internal/StructuredSwiftRepresentation.swift +++ b/Sources/GRPCCodeGen/Internal/StructuredSwiftRepresentation.swift @@ -420,6 +420,16 @@ indirect enum ExistingTypeDescription: Equatable, Codable { /// /// For example, `[String: Foo]`. case dictionaryValue(ExistingTypeDescription) + + /// A type with the `some` keyword in front of it. + /// + /// For example, `some Foo`. + case some(ExistingTypeDescription) + + /// A closure signature as a type. + /// + /// For example: `(String) async throws -> Int`. + case closure(ClosureSignatureDescription) } /// A description of a typealias declaration. @@ -483,7 +493,7 @@ struct ParameterDescription: Equatable, Codable { /// The type name of the parameter. /// /// For example, in `bar baz: String = "hi"`, `type` is `String`. - var type: ExistingTypeDescription + var type: ExistingTypeDescription? = nil /// A default value of the parameter. /// @@ -508,7 +518,10 @@ enum FunctionKind: Equatable, Codable { /// A function or a method. Can be static. /// /// For example `foo()`, where `name` is `foo`. - case function(name: String, isStatic: Bool) + case function( + name: String, + isStatic: Bool + ) } /// A function keyword, such as `async` and `throws`. @@ -519,6 +532,9 @@ enum FunctionKeyword: Equatable, Codable { /// A function that can throw an error. case `throws` + + /// A function that can rethrow an error. + case `rethrows` } /// A description of a function signature. @@ -532,6 +548,9 @@ struct FunctionSignatureDescription: Equatable, Codable { /// The kind of the function. var kind: FunctionKind + /// The generic types of the function. + var generics: [ExistingTypeDescription] = [] + /// The parameters of the function. var parameters: [ParameterDescription] = [] @@ -540,6 +559,9 @@ struct FunctionSignatureDescription: Equatable, Codable { /// The return type name of the function, such as `Int`. var returnType: Expression? = nil + + /// The where clause for a generic function. + var whereClause: WhereClause? } /// A description of a function definition. @@ -575,17 +597,21 @@ struct FunctionDescription: Equatable, Codable { init( accessModifier: AccessModifier? = nil, kind: FunctionKind, + generics: [ExistingTypeDescription] = [], parameters: [ParameterDescription] = [], keywords: [FunctionKeyword] = [], returnType: Expression? = nil, + whereClause: WhereClause? = nil, body: [CodeBlock]? = nil ) { self.signature = .init( accessModifier: accessModifier, kind: kind, + generics: generics, parameters: parameters, keywords: keywords, - returnType: returnType + returnType: returnType, + whereClause: whereClause ) self.body = body } @@ -601,22 +627,45 @@ struct FunctionDescription: Equatable, Codable { init( accessModifier: AccessModifier? = nil, kind: FunctionKind, + generics: [ExistingTypeDescription] = [], parameters: [ParameterDescription] = [], keywords: [FunctionKeyword] = [], returnType: Expression? = nil, + whereClause: WhereClause? = nil, body: [Expression] ) { self.init( accessModifier: accessModifier, kind: kind, + generics: generics, parameters: parameters, keywords: keywords, returnType: returnType, + whereClause: whereClause, body: body.map { .expression($0) } ) } } +/// A description of a closure signature. +/// +/// For example: `(String) async throws -> Int`. +struct ClosureSignatureDescription: Equatable, Codable { + /// The parameters of the function. + var parameters: [ParameterDescription] = [] + + /// The keywords of the function, such as `async` and `throws.` + var keywords: [FunctionKeyword] = [] + + /// The return type name of the function, such as `Int`. + var returnType: Expression? = nil + + /// The ``@Sendable`` attribute. + var sendable: Bool = false + + /// The ``@escaping`` attribute. + var escaping: Bool = false +} /// A description of the associated value of an enum case. /// /// For example, in `case foo(bar: String)`, the associated value @@ -1235,18 +1284,22 @@ extension Declaration { static func function( accessModifier: AccessModifier? = nil, kind: FunctionKind, + generics: [ExistingTypeDescription] = [], parameters: [ParameterDescription], keywords: [FunctionKeyword] = [], returnType: Expression? = nil, + whereClause: WhereClause?, body: [CodeBlock]? = nil ) -> Self { .function( .init( accessModifier: accessModifier, kind: kind, + generics: generics, parameters: parameters, keywords: keywords, returnType: returnType, + whereClause: whereClause, body: body ) ) @@ -1327,7 +1380,9 @@ extension FunctionKind { static var initializer: Self { .initializer(failable: false) } /// Returns a non-static function kind. - static func function(name: String) -> Self { .function(name: name, isStatic: false) } + static func function(name: String) -> Self { + .function(name: name, isStatic: false) + } } extension CodeBlock { diff --git a/Sources/GRPCCodeGen/Internal/Translator/ClientCodeTranslator.swift b/Sources/GRPCCodeGen/Internal/Translator/ClientCodeTranslator.swift new file mode 100644 index 000000000..385f513f7 --- /dev/null +++ b/Sources/GRPCCodeGen/Internal/Translator/ClientCodeTranslator.swift @@ -0,0 +1,432 @@ +/* + * Copyright 2023, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// Creates a representation for the client code that will be generated based on the ``CodeGenerationRequest`` object +/// specifications, using types from ``StructuredSwiftRepresentation``. +/// +/// For example, in the case of a service called "Bar", in the "foo" namespace which has +/// one method "baz", the ``ClientCodeTranslator`` will create +/// a representation for the following generated code: +/// +/// ```swift +/// public protocol foo_BarClientProtocol: Sendable { +/// func baz( +/// request: ClientRequest.Single, +/// serializer: some MessageSerializer, +/// deserializer: some MessageDeserializer, +/// _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R +/// ) async throws -> ServerResponse.Stream +/// } +/// extension foo.Bar.ClientProtocol { +/// public func get( +/// request: ClientRequest.Single, +/// _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R +/// ) async rethrows -> R { +/// try await self.baz( +/// request: request, +/// serializer: ProtobufSerializer(), +/// deserializer: ProtobufDeserializer(), +/// body +/// ) +/// } +/// struct foo_BarClient: foo.Bar.ClientProtocol { +/// let client: GRPCCore.GRPCClient +/// init(client: GRPCCore.GRPCClient) { +/// self.client = client +/// } +/// func methodA( +/// request: ClientRequest.Stream, +/// serializer: some MessageSerializer, +/// deserializer: some MessageDeserializer, +/// _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R +/// ) async rethrows -> R { +/// try await self.client.clientStreaming( +/// request: request, +/// descriptor: namespaceA.ServiceA.Methods.methodA.descriptor, +/// serializer: serializer, +/// deserializer: deserializer, +/// handler: body +/// ) +/// } +/// } +///``` +struct ClientCodeTranslator: SpecializedTranslator { + func translate(from codeGenerationRequest: CodeGenerationRequest) throws -> [CodeBlock] { + var codeBlocks = [CodeBlock]() + + for service in codeGenerationRequest.services { + codeBlocks.append( + .declaration( + .commentable( + .doc(service.documentation), + self.makeClientProtocol(for: service, in: codeGenerationRequest) + ) + ) + ) + codeBlocks.append( + .declaration(self.makeExtensionProtocol(for: service, in: codeGenerationRequest)) + ) + codeBlocks.append( + .declaration( + .commentable( + .doc(service.documentation), + self.makeClientStruct(for: service, in: codeGenerationRequest) + ) + ) + ) + } + return codeBlocks + } +} + +extension ClientCodeTranslator { + private func makeClientProtocol( + for service: CodeGenerationRequest.ServiceDescriptor, + in codeGenerationRequest: CodeGenerationRequest + ) -> Declaration { + let methods = service.methods.map { + self.makeClientProtocolMethod( + for: $0, + in: service, + from: codeGenerationRequest, + generateSerializerDeserializer: false + ) + } + + let clientProtocol = Declaration.protocol( + ProtocolDescription( + name: "\(service.namespacedPrefix)ClientProtocol", + conformances: ["Sendable"], + members: methods + ) + ) + return clientProtocol + } + + private func makeExtensionProtocol( + for service: CodeGenerationRequest.ServiceDescriptor, + in codeGenerationRequest: CodeGenerationRequest + ) -> Declaration { + let methods = service.methods.map { + self.makeClientProtocolMethod( + for: $0, + in: service, + from: codeGenerationRequest, + generateSerializerDeserializer: true + ) + } + let clientProtocolExtension = Declaration.extension( + ExtensionDescription( + onType: "\(service.namespacedTypealiasPrefix).ClientProtocol", + declarations: methods + ) + ) + return clientProtocolExtension + } + + private func makeClientProtocolMethod( + for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor, + in service: CodeGenerationRequest.ServiceDescriptor, + from codeGenerationRequest: CodeGenerationRequest, + generateSerializerDeserializer: Bool + ) -> Declaration { + let methodParameters = self.makeParameters( + for: method, + in: service, + from: codeGenerationRequest, + generateSerializerDeserializer: generateSerializerDeserializer + ) + let functionSignature = FunctionSignatureDescription( + kind: .function( + name: method.name, + isStatic: false + ), + generics: [.member("R")], + parameters: methodParameters, + keywords: [.async, .throws], + returnType: .identifierType(.member("R")), + whereClause: WhereClause(requirements: [.conformance("R", "Sendable")]) + ) + + if generateSerializerDeserializer { + let body = self.makeSerializerDeserializerCall( + for: method, + in: service, + from: codeGenerationRequest + ) + return .function(signature: functionSignature, body: body) + } else { + return .commentable(.doc(method.documentation), .function(signature: functionSignature)) + } + } + + private func makeSerializerDeserializerCall( + for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor, + in service: CodeGenerationRequest.ServiceDescriptor, + from codeGenerationRequest: CodeGenerationRequest + ) -> [CodeBlock] { + let functionCall = Expression.functionCall( + calledExpression: .memberAccess( + MemberAccessDescription(left: .identifierPattern("self"), right: method.name) + ), + arguments: [ + FunctionArgumentDescription(label: "request", expression: .identifierPattern("request")), + FunctionArgumentDescription( + label: "serializer", + expression: .identifierPattern( + codeGenerationRequest.lookupSerializer( + self.methodInputOutputTypealias(for: method, service: service, type: .input) + ) + ) + ), + FunctionArgumentDescription( + label: "deserializer", + expression: .identifierPattern( + codeGenerationRequest.lookupDeserializer( + self.methodInputOutputTypealias(for: method, service: service, type: .output) + ) + ) + ), + FunctionArgumentDescription(expression: .identifierPattern("body")), + ] + ) + let awaitFunctionCall = Expression.unaryKeyword(kind: .await, expression: functionCall) + let tryAwaitFunctionCall = Expression.unaryKeyword(kind: .try, expression: awaitFunctionCall) + + return [CodeBlock(item: .expression(tryAwaitFunctionCall))] + } + + private func makeParameters( + for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor, + in service: CodeGenerationRequest.ServiceDescriptor, + from codeGenerationRequest: CodeGenerationRequest, + generateSerializerDeserializer: Bool + ) -> [ParameterDescription] { + var parameters = [ParameterDescription]() + + parameters.append(self.clientRequestParameter(for: method, in: service)) + if !generateSerializerDeserializer { + parameters.append(self.serializerParameter(for: method, in: service)) + parameters.append(self.deserializerParameter(for: method, in: service)) + } + parameters.append(self.bodyParameter(for: method, in: service)) + return parameters + } + private func clientRequestParameter( + for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor, + in service: CodeGenerationRequest.ServiceDescriptor + ) -> ParameterDescription { + let requestType = method.isInputStreaming ? "Stream" : "Single" + let clientRequestType = ExistingTypeDescription.member(["ClientRequest", requestType]) + return ParameterDescription( + label: "request", + type: .generic( + wrapper: clientRequestType, + wrapped: .member( + self.methodInputOutputTypealias(for: method, service: service, type: .input) + ) + ) + ) + } + + private func serializerParameter( + for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor, + in service: CodeGenerationRequest.ServiceDescriptor + ) -> ParameterDescription { + return ParameterDescription( + label: "serializer", + type: ExistingTypeDescription.some( + .generic( + wrapper: .member("MessageSerializer"), + wrapped: .member( + self.methodInputOutputTypealias(for: method, service: service, type: .input) + ) + ) + ) + ) + } + + private func deserializerParameter( + for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor, + in service: CodeGenerationRequest.ServiceDescriptor + ) -> ParameterDescription { + return ParameterDescription( + label: "deserializer", + type: ExistingTypeDescription.some( + .generic( + wrapper: .member("MessageDeserializer"), + wrapped: .member( + self.methodInputOutputTypealias(for: method, service: service, type: .output) + ) + ) + ) + ) + } + + private func bodyParameter( + for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor, + in service: CodeGenerationRequest.ServiceDescriptor + ) -> ParameterDescription { + let clientStreaming = method.isOutputStreaming ? "Stream" : "Single" + let closureParameterType = ExistingTypeDescription.generic( + wrapper: .member(["ClientResponse", clientStreaming]), + wrapped: .member( + self.methodInputOutputTypealias(for: method, service: service, type: .output) + ) + ) + + let bodyClosure = ClosureSignatureDescription( + parameters: [.init(type: closureParameterType)], + keywords: [.async, .throws], + returnType: .identifierType(.member("R")), + sendable: true, + escaping: true + ) + return ParameterDescription(name: "body", type: .closure(bodyClosure)) + } + + private func makeClientStruct( + for service: CodeGenerationRequest.ServiceDescriptor, + in codeGenerationRequest: CodeGenerationRequest + ) -> Declaration { + let clientProperty = Declaration.variable( + kind: .let, + left: "client", + type: .member(["GRPCCore", "GRPCClient"]) + ) + let initializer = self.makeClientVariable() + let methods = service.methods.map { + Declaration.commentable( + .doc($0.documentation), + self.makeClientMethod(for: $0, in: service, from: codeGenerationRequest) + ) + } + + return .struct( + StructDescription( + name: "\(service.namespacedPrefix)Client", + conformances: ["\(service.namespacedTypealiasPrefix).ClientProtocol"], + members: [clientProperty, initializer] + methods + ) + ) + } + + private func makeClientVariable() -> Declaration { + let initializerBody = Expression.assignment( + left: .memberAccess( + MemberAccessDescription(left: .identifierPattern("self"), right: "client") + ), + right: .identifierPattern("client") + ) + return .function( + signature: .init( + kind: .initializer, + parameters: [.init(label: "client", type: .member(["GRPCCore", "GRPCClient"]))] + ), + body: [CodeBlock(item: .expression(initializerBody))] + ) + } + + private func clientMethod( + isInputStreaming: Bool, + isOutputStreaming: Bool + ) -> String { + switch (isInputStreaming, isOutputStreaming) { + case (true, true): + return "bidirectionalStreaming" + case (true, false): + return "clientStreaming" + case (false, true): + return "serverStreaming" + case (false, false): + return "unary" + } + } + + private func makeClientMethod( + for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor, + in service: CodeGenerationRequest.ServiceDescriptor, + from codeGenerationRequest: CodeGenerationRequest + ) -> Declaration { + let parameters = self.makeParameters( + for: method, + in: service, + from: codeGenerationRequest, + generateSerializerDeserializer: false + ) + let grpcMethodName = self.clientMethod( + isInputStreaming: method.isInputStreaming, + isOutputStreaming: method.isOutputStreaming + ) + let functionCall = Expression.functionCall( + calledExpression: .memberAccess( + MemberAccessDescription(left: .identifierPattern("self.client"), right: "\(grpcMethodName)") + ), + arguments: [ + .init(label: "request", expression: .identifierPattern("request")), + .init( + label: "descriptor", + expression: .identifierPattern( + "\(service.namespacedTypealiasPrefix).Methods.\(method.name).descriptor" + ) + ), + .init(label: "serializer", expression: .identifierPattern("serializer")), + .init(label: "deserializer", expression: .identifierPattern("deserializer")), + .init(label: "handler", expression: .identifierPattern("body")), + ] + ) + let body = UnaryKeywordDescription( + kind: .try, + expression: .unaryKeyword(kind: .await, expression: functionCall) + ) + + return .function( + kind: .function( + name: "\(method.name)", + isStatic: false + ), + generics: [.member("R")], + parameters: parameters, + keywords: [.async, .throws], + returnType: .identifierType(.member("R")), + whereClause: WhereClause(requirements: [.conformance("R", "Sendable")]), + body: [.expression(.unaryKeyword(body))] + ) + } + + fileprivate enum InputOutputType { + case input + case output + } + + /// Generates the fully qualified name of the typealias for the input or output type of a method. + private func methodInputOutputTypealias( + for method: CodeGenerationRequest.ServiceDescriptor.MethodDescriptor, + service: CodeGenerationRequest.ServiceDescriptor, + type: InputOutputType + ) -> String { + var components: String = "\(service.namespacedTypealiasPrefix).Methods.\(method.name)" + + switch type { + case .input: + components.append(".Input") + case .output: + components.append(".Output") + } + + return components + } +} diff --git a/Sources/GRPCCodeGen/Internal/Translator/IDLToStructuredSwiftTranslator.swift b/Sources/GRPCCodeGen/Internal/Translator/IDLToStructuredSwiftTranslator.swift index 85fadaeda..0a707fd31 100644 --- a/Sources/GRPCCodeGen/Internal/Translator/IDLToStructuredSwiftTranslator.swift +++ b/Sources/GRPCCodeGen/Internal/Translator/IDLToStructuredSwiftTranslator.swift @@ -15,8 +15,6 @@ */ struct IDLToStructuredSwiftTranslator: Translator { - private let serverCodeTranslator = ServerCodeTranslator() - func translate( codeGenerationRequest: CodeGenerationRequest, client: Bool, @@ -28,14 +26,23 @@ struct IDLToStructuredSwiftTranslator: Translator { let imports: [ImportDescription] = [ ImportDescription(moduleName: "GRPCCore") ] + var codeBlocks: [CodeBlock] = [] codeBlocks.append( contentsOf: try typealiasTranslator.translate(from: codeGenerationRequest) ) if server { + let serverCodeTranslator = ServerCodeTranslator() + codeBlocks.append( + contentsOf: try serverCodeTranslator.translate(from: codeGenerationRequest) + ) + } + + if client { + let clientCodeTranslator = ClientCodeTranslator() codeBlocks.append( - contentsOf: try self.serverCodeTranslator.translate(from: codeGenerationRequest) + contentsOf: try clientCodeTranslator.translate(from: codeGenerationRequest) ) } diff --git a/Tests/GRPCCodeGenTests/Internal/Renderer/TextBasedRendererTests.swift b/Tests/GRPCCodeGenTests/Internal/Renderer/TextBasedRendererTests.swift index bbe99de1c..0b822be07 100644 --- a/Tests/GRPCCodeGenTests/Internal/Renderer/TextBasedRendererTests.swift +++ b/Tests/GRPCCodeGenTests/Internal/Renderer/TextBasedRendererTests.swift @@ -381,6 +381,39 @@ final class Test_TextBasedRenderer: XCTestCase { ) } + func testGenericFunction() throws { + try _test( + .init( + accessModifier: .public, + kind: .function(name: "f"), + generics: [.member("R")], + parameters: [], + whereClause: WhereClause(requirements: [.conformance("R", "Sendable")]), + body: [] + ), + renderedBy: TextBasedRenderer.renderFunction, + rendersAs: #""" + public func f() where R: Sendable {} + """# + ) + try _test( + .init( + accessModifier: .public, + kind: .function(name: "f"), + generics: [.member("R"), .member("T")], + parameters: [], + whereClause: WhereClause(requirements: [ + .conformance("R", "Sendable"), .conformance("T", "Encodable"), + ]), + body: [] + ), + renderedBy: TextBasedRenderer.renderFunction, + rendersAs: #""" + public func f() where R: Sendable, T: Encodable {} + """# + ) + } + func testFunction() throws { try _test( .init(accessModifier: .public, kind: .function(name: "f"), parameters: [], body: []), @@ -436,7 +469,7 @@ final class Test_TextBasedRenderer: XCTestCase { func testIdentifiers() throws { try _test( .pattern("foo"), - renderedBy: TextBasedRenderer.renderedIdentifier, + renderedBy: TextBasedRenderer.renderIdentifier, rendersAs: #""" foo """# diff --git a/Tests/GRPCCodeGenTests/Internal/Translator/ClientCodeTranslatorSnippetBasedTests.swift b/Tests/GRPCCodeGenTests/Internal/Translator/ClientCodeTranslatorSnippetBasedTests.swift new file mode 100644 index 000000000..144749390 --- /dev/null +++ b/Tests/GRPCCodeGenTests/Internal/Translator/ClientCodeTranslatorSnippetBasedTests.swift @@ -0,0 +1,542 @@ +/* + * Copyright 2023, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import XCTest + +@testable import GRPCCodeGen + +final class ClientCodeTranslatorSnippetBasedTests: XCTestCase { + typealias MethodDescriptor = GRPCCodeGen.CodeGenerationRequest.ServiceDescriptor.MethodDescriptor + typealias ServiceDescriptor = GRPCCodeGen.CodeGenerationRequest.ServiceDescriptor + + func testClientCodeTranslatorUnaryMethod() throws { + let method = MethodDescriptor( + documentation: "Documentation for MethodA", + name: "methodA", + isInputStreaming: false, + isOutputStreaming: false, + inputType: "NamespaceA_ServiceARequest", + outputType: "NamespaceA_ServiceAResponse" + ) + let service = ServiceDescriptor( + documentation: "Documentation for ServiceA", + name: "ServiceA", + namespace: "namespaceA", + methods: [method] + ) + let expectedSwift = + """ + /// Documentation for ServiceA + protocol namespaceA_ServiceAClientProtocol: Sendable { + /// Documentation for MethodA + func methodA( + request: ClientRequest.Single, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable + } + extension namespaceA.ServiceA.ClientProtocol { + func methodA( + request: ClientRequest.Single, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable { + try await self.methodA( + request: request, + serializer: ProtobufSerializer(), + deserializer: ProtobufDeserializer(), + body + ) + } + } + /// Documentation for ServiceA + struct namespaceA_ServiceAClient: namespaceA.ServiceA.ClientProtocol { + let client: GRPCCore.GRPCClient + init(client: GRPCCore.GRPCClient) { + self.client = client + } + /// Documentation for MethodA + func methodA( + request: ClientRequest.Single, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable { + try await self.client.unary( + request: request, + descriptor: namespaceA.ServiceA.Methods.methodA.descriptor, + serializer: serializer, + deserializer: deserializer, + handler: body + ) + } + } + """ + + try self.assertClientCodeTranslation( + codeGenerationRequest: makeCodeGenerationRequest(services: [service]), + expectedSwift: expectedSwift + ) + } + + func testClientCodeTranslatorClientStreamingMethod() throws { + let method = MethodDescriptor( + documentation: "Documentation for MethodA", + name: "methodA", + isInputStreaming: true, + isOutputStreaming: false, + inputType: "NamespaceA_ServiceARequest", + outputType: "NamespaceA_ServiceAResponse" + ) + let service = ServiceDescriptor( + documentation: "Documentation for ServiceA", + name: "ServiceA", + namespace: "namespaceA", + methods: [method] + ) + let expectedSwift = + """ + /// Documentation for ServiceA + protocol namespaceA_ServiceAClientProtocol: Sendable { + /// Documentation for MethodA + func methodA( + request: ClientRequest.Stream, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable + } + extension namespaceA.ServiceA.ClientProtocol { + func methodA( + request: ClientRequest.Stream, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable { + try await self.methodA( + request: request, + serializer: ProtobufSerializer(), + deserializer: ProtobufDeserializer(), + body + ) + } + } + /// Documentation for ServiceA + struct namespaceA_ServiceAClient: namespaceA.ServiceA.ClientProtocol { + let client: GRPCCore.GRPCClient + init(client: GRPCCore.GRPCClient) { + self.client = client + } + /// Documentation for MethodA + func methodA( + request: ClientRequest.Stream, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable { + try await self.client.clientStreaming( + request: request, + descriptor: namespaceA.ServiceA.Methods.methodA.descriptor, + serializer: serializer, + deserializer: deserializer, + handler: body + ) + } + } + """ + + try self.assertClientCodeTranslation( + codeGenerationRequest: makeCodeGenerationRequest(services: [service]), + expectedSwift: expectedSwift + ) + } + + func testClientCodeTranslatorServerStreamingMethod() throws { + let method = MethodDescriptor( + documentation: "Documentation for MethodA", + name: "methodA", + isInputStreaming: false, + isOutputStreaming: true, + inputType: "NamespaceA_ServiceARequest", + outputType: "NamespaceA_ServiceAResponse" + ) + let service = ServiceDescriptor( + documentation: "Documentation for ServiceA", + name: "ServiceA", + namespace: "namespaceA", + methods: [method] + ) + let expectedSwift = + """ + /// Documentation for ServiceA + protocol namespaceA_ServiceAClientProtocol: Sendable { + /// Documentation for MethodA + func methodA( + request: ClientRequest.Single, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Stream) async throws -> R + ) async throws -> R where R: Sendable + } + extension namespaceA.ServiceA.ClientProtocol { + func methodA( + request: ClientRequest.Single, + _ body: @Sendable @escaping (ClientResponse.Stream) async throws -> R + ) async throws -> R where R: Sendable { + try await self.methodA( + request: request, + serializer: ProtobufSerializer(), + deserializer: ProtobufDeserializer(), + body + ) + } + } + /// Documentation for ServiceA + struct namespaceA_ServiceAClient: namespaceA.ServiceA.ClientProtocol { + let client: GRPCCore.GRPCClient + init(client: GRPCCore.GRPCClient) { + self.client = client + } + /// Documentation for MethodA + func methodA( + request: ClientRequest.Single, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Stream) async throws -> R + ) async throws -> R where R: Sendable { + try await self.client.serverStreaming( + request: request, + descriptor: namespaceA.ServiceA.Methods.methodA.descriptor, + serializer: serializer, + deserializer: deserializer, + handler: body + ) + } + } + """ + + try self.assertClientCodeTranslation( + codeGenerationRequest: makeCodeGenerationRequest(services: [service]), + expectedSwift: expectedSwift + ) + } + + func testClientCodeTranslatorBidirectionalStreamingMethod() throws { + let method = MethodDescriptor( + documentation: "Documentation for MethodA", + name: "methodA", + isInputStreaming: true, + isOutputStreaming: true, + inputType: "NamespaceA_ServiceARequest", + outputType: "NamespaceA_ServiceAResponse" + ) + let service = ServiceDescriptor( + documentation: "Documentation for ServiceA", + name: "ServiceA", + namespace: "namespaceA", + methods: [method] + ) + let expectedSwift = + """ + /// Documentation for ServiceA + protocol namespaceA_ServiceAClientProtocol: Sendable { + /// Documentation for MethodA + func methodA( + request: ClientRequest.Stream, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Stream) async throws -> R + ) async throws -> R where R: Sendable + } + extension namespaceA.ServiceA.ClientProtocol { + func methodA( + request: ClientRequest.Stream, + _ body: @Sendable @escaping (ClientResponse.Stream) async throws -> R + ) async throws -> R where R: Sendable { + try await self.methodA( + request: request, + serializer: ProtobufSerializer(), + deserializer: ProtobufDeserializer(), + body + ) + } + } + /// Documentation for ServiceA + struct namespaceA_ServiceAClient: namespaceA.ServiceA.ClientProtocol { + let client: GRPCCore.GRPCClient + init(client: GRPCCore.GRPCClient) { + self.client = client + } + /// Documentation for MethodA + func methodA( + request: ClientRequest.Stream, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Stream) async throws -> R + ) async throws -> R where R: Sendable { + try await self.client.bidirectionalStreaming( + request: request, + descriptor: namespaceA.ServiceA.Methods.methodA.descriptor, + serializer: serializer, + deserializer: deserializer, + handler: body + ) + } + } + """ + + try self.assertClientCodeTranslation( + codeGenerationRequest: makeCodeGenerationRequest(services: [service]), + expectedSwift: expectedSwift + ) + } + + func testClientCodeTranslatorMultipleMethods() throws { + let methodA = MethodDescriptor( + documentation: "Documentation for MethodA", + name: "methodA", + isInputStreaming: true, + isOutputStreaming: false, + inputType: "NamespaceA_ServiceARequest", + outputType: "NamespaceA_ServiceAResponse" + ) + let methodB = MethodDescriptor( + documentation: "Documentation for MethodB", + name: "methodB", + isInputStreaming: false, + isOutputStreaming: true, + inputType: "NamespaceA_ServiceARequest", + outputType: "NamespaceA_ServiceAResponse" + ) + let service = ServiceDescriptor( + documentation: "Documentation for ServiceA", + name: "ServiceA", + namespace: "namespaceA", + methods: [methodA, methodB] + ) + let expectedSwift = + """ + /// Documentation for ServiceA + protocol namespaceA_ServiceAClientProtocol: Sendable { + /// Documentation for MethodA + func methodA( + request: ClientRequest.Stream, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable + /// Documentation for MethodB + func methodB( + request: ClientRequest.Single, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Stream) async throws -> R + ) async throws -> R where R: Sendable + } + extension namespaceA.ServiceA.ClientProtocol { + func methodA( + request: ClientRequest.Stream, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable { + try await self.methodA( + request: request, + serializer: ProtobufSerializer(), + deserializer: ProtobufDeserializer(), + body + ) + } + func methodB( + request: ClientRequest.Single, + _ body: @Sendable @escaping (ClientResponse.Stream) async throws -> R + ) async throws -> R where R: Sendable { + try await self.methodB( + request: request, + serializer: ProtobufSerializer(), + deserializer: ProtobufDeserializer(), + body + ) + } + } + /// Documentation for ServiceA + struct namespaceA_ServiceAClient: namespaceA.ServiceA.ClientProtocol { + let client: GRPCCore.GRPCClient + init(client: GRPCCore.GRPCClient) { + self.client = client + } + /// Documentation for MethodA + func methodA( + request: ClientRequest.Stream, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable { + try await self.client.clientStreaming( + request: request, + descriptor: namespaceA.ServiceA.Methods.methodA.descriptor, + serializer: serializer, + deserializer: deserializer, + handler: body + ) + } + /// Documentation for MethodB + func methodB( + request: ClientRequest.Single, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Stream) async throws -> R + ) async throws -> R where R: Sendable { + try await self.client.serverStreaming( + request: request, + descriptor: namespaceA.ServiceA.Methods.methodB.descriptor, + serializer: serializer, + deserializer: deserializer, + handler: body + ) + } + } + """ + + try self.assertClientCodeTranslation( + codeGenerationRequest: makeCodeGenerationRequest(services: [service]), + expectedSwift: expectedSwift + ) + } + + func testClientCodeTranslatorNoNamespaceService() throws { + let method = MethodDescriptor( + documentation: "Documentation for MethodA", + name: "methodA", + isInputStreaming: false, + isOutputStreaming: false, + inputType: "ServiceARequest", + outputType: "ServiceAResponse" + ) + let service = ServiceDescriptor( + documentation: "Documentation for ServiceA", + name: "ServiceA", + namespace: "", + methods: [method] + ) + let expectedSwift = + """ + /// Documentation for ServiceA + protocol ServiceAClientProtocol: Sendable { + /// Documentation for MethodA + func methodA( + request: ClientRequest.Single, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable + } + extension ServiceA.ClientProtocol { + func methodA( + request: ClientRequest.Single, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable { + try await self.methodA( + request: request, + serializer: ProtobufSerializer(), + deserializer: ProtobufDeserializer(), + body + ) + } + } + /// Documentation for ServiceA + struct ServiceAClient: ServiceA.ClientProtocol { + let client: GRPCCore.GRPCClient + init(client: GRPCCore.GRPCClient) { + self.client = client + } + /// Documentation for MethodA + func methodA( + request: ClientRequest.Single, + serializer: some MessageSerializer, + deserializer: some MessageDeserializer, + _ body: @Sendable @escaping (ClientResponse.Single) async throws -> R + ) async throws -> R where R: Sendable { + try await self.client.unary( + request: request, + descriptor: ServiceA.Methods.methodA.descriptor, + serializer: serializer, + deserializer: deserializer, + handler: body + ) + } + } + """ + + try self.assertClientCodeTranslation( + codeGenerationRequest: makeCodeGenerationRequest(services: [service]), + expectedSwift: expectedSwift + ) + } + + func testClientCodeTranslatorMultipleServices() throws { + let serviceA = ServiceDescriptor( + documentation: "Documentation for ServiceA", + name: "ServiceA", + namespace: "namespaceA", + methods: [] + ) + let serviceB = ServiceDescriptor( + documentation: "Documentation for ServiceB", + name: "ServiceB", + namespace: "", + methods: [] + ) + let expectedSwift = + """ + /// Documentation for ServiceA + protocol namespaceA_ServiceAClientProtocol: Sendable {} + extension namespaceA.ServiceA.ClientProtocol { + } + /// Documentation for ServiceA + struct namespaceA_ServiceAClient: namespaceA.ServiceA.ClientProtocol { + let client: GRPCCore.GRPCClient + init(client: GRPCCore.GRPCClient) { + self.client = client + } + } + /// Documentation for ServiceB + protocol ServiceBClientProtocol: Sendable {} + extension ServiceB.ClientProtocol { + } + /// Documentation for ServiceB + struct ServiceBClient: ServiceB.ClientProtocol { + let client: GRPCCore.GRPCClient + init(client: GRPCCore.GRPCClient) { + self.client = client + } + } + """ + + try self.assertClientCodeTranslation( + codeGenerationRequest: makeCodeGenerationRequest(services: [serviceA, serviceB]), + expectedSwift: expectedSwift + ) + } + + private func assertClientCodeTranslation( + codeGenerationRequest: CodeGenerationRequest, + expectedSwift: String + ) throws { + let translator = ClientCodeTranslator() + let codeBlocks = try translator.translate(from: codeGenerationRequest) + let renderer = TextBasedRenderer.default + renderer.renderCodeBlocks(codeBlocks) + let contents = renderer.renderedContents() + try XCTAssertEqualWithDiff(contents, expectedSwift) + } +}