From 21acdcb4cc1add0e992ef5359923896c8c71b522 Mon Sep 17 00:00:00 2001 From: Josh Wright Date: Wed, 22 May 2024 15:16:38 -0700 Subject: [PATCH] declaration --- Example/main.swift | 6 +- PapyrusCore/Sources/Macros.swift | 4 +- PapyrusCore/Sources/Request.swift | 1 - PapyrusPlugin/Sources/APIMacro.swift | 390 ++++-------------- PapyrusPlugin/Sources/APIRoutesMacro.swift | 1 + PapyrusPlugin/Sources/DecoratorMacro.swift | 19 +- ...Attribute.swift => EndpointModifier.swift} | 119 ++---- PapyrusPlugin/Sources/EndpointParameter.swift | 75 ++++ PapyrusPlugin/Sources/MockMacro.swift | 45 +- .../Utilities/AttributeSyntax+Utilities.swift | 11 - .../Sources/Utilities/Declaration.swift | 219 ++++++++++ .../FunctionDeclSyntax+Utilities.swift | 113 ----- .../FunctionParameterSyntax+Utilities.swift | 7 - .../Sources/Utilities/Macro+Utilities.swift | 8 - .../ProtocolDeclSyntax+Utilities.swift | 16 - .../Sources/Utilities/String+Utilities.swift | 6 + .../Utilities/SwiftSyntax+Utilities.swift | 92 +++++ PapyrusPlugin/Tests/APIMacroTests.swift | 48 ++- 18 files changed, 577 insertions(+), 603 deletions(-) create mode 100644 PapyrusPlugin/Sources/APIRoutesMacro.swift rename PapyrusPlugin/Sources/{APIAttribute.swift => EndpointModifier.swift} (57%) create mode 100644 PapyrusPlugin/Sources/EndpointParameter.swift delete mode 100644 PapyrusPlugin/Sources/Utilities/AttributeSyntax+Utilities.swift create mode 100644 PapyrusPlugin/Sources/Utilities/Declaration.swift delete mode 100644 PapyrusPlugin/Sources/Utilities/FunctionDeclSyntax+Utilities.swift delete mode 100644 PapyrusPlugin/Sources/Utilities/FunctionParameterSyntax+Utilities.swift delete mode 100644 PapyrusPlugin/Sources/Utilities/Macro+Utilities.swift delete mode 100644 PapyrusPlugin/Sources/Utilities/ProtocolDeclSyntax+Utilities.swift create mode 100644 PapyrusPlugin/Sources/Utilities/SwiftSyntax+Utilities.swift diff --git a/Example/main.swift b/Example/main.swift index 65a9a52..052b063 100644 --- a/Example/main.swift +++ b/Example/main.swift @@ -4,8 +4,6 @@ import Papyrus @API @Mock -@KeyMapping(.snakeCase) -@Authorization(.bearer("")) protocol Sample { @GET("/todos") func getTodos() async throws -> [Todo] @@ -30,6 +28,10 @@ public struct Todo: Codable { // MARK: 1. Create a Provider with any custom configuration. let provider = Provider(baseURL: "http://127.0.0.1:3000") + .modifyRequests { + $0.addAuthorization(.bearer("")) + $0.keyMapping = .snakeCase + } .intercept { req, next in let start = Date() let res = try await next(req) diff --git a/PapyrusCore/Sources/Macros.swift b/PapyrusCore/Sources/Macros.swift index bfc067f..37b7934 100644 --- a/PapyrusCore/Sources/Macros.swift +++ b/PapyrusCore/Sources/Macros.swift @@ -3,10 +3,10 @@ import Foundation // MARK: Protocol attributes @attached(peer, names: suffixed(API)) -public macro API(_ typeName: String? = nil) = #externalMacro(module: "PapyrusPlugin", type: "APIMacro") +public macro API() = #externalMacro(module: "PapyrusPlugin", type: "APIMacro") @attached(peer, names: suffixed(Mock)) -public macro Mock(_ typeName: String? = nil) = #externalMacro(module: "PapyrusPlugin", type: "MockMacro") +public macro Mock() = #externalMacro(module: "PapyrusPlugin", type: "MockMacro") // MARK: Function or Protocol attributes diff --git a/PapyrusCore/Sources/Request.swift b/PapyrusCore/Sources/Request.swift index c93392e..4f0ecdf 100644 --- a/PapyrusCore/Sources/Request.swift +++ b/PapyrusCore/Sources/Request.swift @@ -7,7 +7,6 @@ public protocol Request { var body: Data? { get set } } - public extension Request { /// Create a cURL command from this instance /// diff --git a/PapyrusPlugin/Sources/APIMacro.swift b/PapyrusPlugin/Sources/APIMacro.swift index 1f82271..984a8d4 100644 --- a/PapyrusPlugin/Sources/APIMacro.swift +++ b/PapyrusPlugin/Sources/APIMacro.swift @@ -1,353 +1,123 @@ -import Foundation import SwiftSyntax import SwiftSyntaxMacros public struct APIMacro: PeerMacro { - public static func expansion(of node: AttributeSyntax, - providingPeersOf declaration: some DeclSyntaxProtocol, - in context: some MacroExpansionContext) throws -> [DeclSyntax] { - try handleError { - guard let type = declaration.as(ProtocolDeclSyntax.self) else { - throw PapyrusPluginError("@API can only be applied to protocols.") - } - - let name = node.firstArgument ?? "\(type.typeName)\(node.attributeName)" - return try type.createAPI(named: name) + public static func expansion( + of node: AttributeSyntax, + providingPeersOf declaration: some DeclSyntaxProtocol, + in context: some MacroExpansionContext + ) throws -> [DeclSyntax] { + guard let proto = declaration.as(ProtocolDeclSyntax.self) else { + throw PapyrusPluginError("@API can only be applied to protocols.") } + + return [ + try proto + .liveService(named: "\(proto.protocolName)\(node.attributeName)") + .declSyntax() + ] } } extension ProtocolDeclSyntax { - func createAPI(named apiName: String) throws -> String { - """ - \(access)struct \(apiName): \(typeName) { - private let provider: PapyrusCore.Provider - - \(access)init(provider: PapyrusCore.Provider) { - self.provider = provider - } - - \(try generateAPIFunctions()) - } - """ - } - - private func generateAPIFunctions() throws -> String { - var functions = try functions - .map { try $0.apiFunction() } - .map { access + $0 } - functions.append(newRequestFunction) - return functions.joined(separator: "\n\n") - } + func liveService(named name: String) throws -> Declaration { + try Declaration("\(access)struct \(name): \(protocolName)") { - private var newRequestFunction: String { - let globalBuilderStatements = apiAttributes.compactMap { $0.apiBuilderStatement() } - let content = globalBuilderStatements.isEmpty - ? """ - provider.newBuilder(method: method, path: path) - """ - : """ - var req = provider.newBuilder(method: method, path: path) - \(globalBuilderStatements.joined(separator: "\n")) - return req - """ - return """ - private func builder(method: String, path: String) -> RequestBuilder { - \(content) - } - """ - } - - private var apiAttributes: [APIAttribute] { - attributes - .compactMap { $0.as(AttributeSyntax.self) } - .compactMap(APIAttribute.init) - } -} + // 0. provider reference & init -extension FunctionDeclSyntax { - fileprivate func apiFunction() throws -> String { - let (method, path) = try apiMethodAndPath() - try validateSignature() + "private let provider: PapyrusCore.Provider" - let pathParameters = path.components(separatedBy: "/") - .compactMap { component in - if component.hasPrefix(":") { - return String(component.dropFirst()) - } else if component.hasPrefix("{") && component.hasSuffix("}") { - return String(component.dropFirst().dropLast()) - } else { - return nil - } + Declaration("init(provider: PapyrusCore.Provider)") { + "self.provider = provider" } - let attributes = parameters.compactMap({ $0.apiAttribute(httpMethod: method, pathParameters: pathParameters) }) - try validateAttributes(attributes) - - var buildRequest = """ - var req = builder(method: "\(method)", path: "\(path)") - """ - - for statement in apiAttributes.compactMap({ $0.apiBuilderStatement() }) { - buildRequest.append("\n" + statement) - } - - for statement in try parameters.compactMap({ try $0.apiBuilderStatement(httpMethod: method, pathParameters: pathParameters) }) { - buildRequest.append("\n" + statement) - } + // 1. live endpoint implementations - return """ - func \(functionName)\(signature) { - \(buildRequest) - \(try handleResponse()) + for function in functions { + try function.liveEndpointFunction(access: access) } - """ - } - private func handleResponse() throws -> String { - switch style { - case .completionHandler: - guard let callbackName else { - throw PapyrusPluginError("No callback found!") - } + // 2. builder used by all live endpoint functions - if returnResponseOnly { - return """ - provider.request(&req) { res in - \(callbackName)(res) - } - """ - } else { - return """ - provider.request(&req) { res in - do { - try res.validate() - \(resultExpression.map { "let res = \($0)" } ?? "") - \(callbackName)(.success(res)) - } catch { - \(callbackName)(.failure(error)) - } - } - """ - } - case .concurrency: - switch responseType { - case .type("Void"), .none: - return "try await provider.request(&req).validate()" - case .type where returnResponseOnly: - return "return try await provider.request(&req)" - case .type, .tuple: - guard let resultExpression else { - throw PapyrusPluginError("Missing result expression!") - } + Declaration("private func builder(method: String, path: String) -> RequestBuilder") { + let modifiers = protocolAttributes.compactMap { EndpointModifier($0) } + if modifiers.isEmpty { + "provider.newBuilder(method: method, path: path)" + } else { + "var req = provider.newBuilder(method: method, path: path)" - return """ - let res = try await provider.request(&req) - try res.validate() - return \(resultExpression) - """ - } - } - } + modifiers.compactMap { $0.builderStatement() } - private func apiMethodAndPath() throws -> (method: String, path: String) { - var method, path: String? - for attribute in apiAttributes { - switch attribute { - case let .http(_method, _path): - guard method == nil, path == nil else { - throw PapyrusPluginError("Only one method per function!") + "return req" } - - (method, path) = (_method, _path) - default: - continue - } - } - - guard let method, let path else { - throw PapyrusPluginError("No method or path!") - } - - return (method, path) - } - - private func validateAttributes(_ apiAttributes: [APIAttribute]) throws { - var bodies = 0, fields = 0 - for attribute in apiAttributes { - switch attribute { - case .body: - bodies += 1 - case .field: - fields += 1 - default: - continue } } - - guard fields == 0 || bodies == 0 else { - throw PapyrusPluginError("Can't have @Body and @Field!") - } - - guard bodies <= 1 else { - throw PapyrusPluginError("Can only have one @Body!") - } } +} - private var resultExpression: String? { - guard !returnResponseOnly else { - return nil +extension FunctionDeclSyntax { + fileprivate func liveEndpointFunction(access: String) throws -> Declaration { + guard effects == ["async", "throws"] else { + throw PapyrusPluginError("Function signature must have `async throws`.") } - switch responseType { - case .tuple(let array): - let elements = array - .map { element in - let expression = element.type == "Response" ? "res" : "try res.decode(\(element.type).self, using: req.responseDecoder)" - return [element.label, expression] - .compactMap { $0 } - .joined(separator: ": ") - } - return """ - ( - \(elements.joined(separator: ",\n")) - ) - """ - case .type(let string): - return "try res.decode(\(string).self, using: req.responseDecoder)" - default: - return nil - } - } + return try Declaration("\(access)func \(functionName)\(signature)") { + let modifiers = functionAttributes.compactMap { EndpointModifier($0) } + let (method, path, pathParameters) = try modifiers.parseMethodAndPath() - private var apiAttributes: [APIAttribute] { - attributes - .compactMap { $0.as(AttributeSyntax.self) } - .compactMap(APIAttribute.init) - } -} + // 0. create a request builder -extension FunctionParameterSyntax { - fileprivate func apiBuilderStatement(httpMethod: String, pathParameters: [String]) throws -> String? { - guard !isClosure else { - return nil - } + "var req = builder(method: \(method.inQuotes), path: \(path.inQuotes))" - return apiAttribute(httpMethod: httpMethod, pathParameters: pathParameters) - .apiBuilderStatement(input: variableName) - } + // 1. add function scope modifiers - func apiAttribute(httpMethod: String, pathParameters: [String]) -> APIAttribute { - if let explicitAPIAttribute { - // If user specifies the attribute, use that. - return explicitAPIAttribute - } else if pathParameters.contains(variableName) { - // If matches a path param, roll with that - return .path(key: nil) - } else if pathParameters.contains(KeyMapping.snakeCase.encode(variableName)) { - // If matches snake cased param, add that - return .path(key: KeyMapping.snakeCase.encode(variableName)) - } else if ["GET", "HEAD", "DELETE"].contains(httpMethod) { - // If method is GET, HEAD, DELETE - return .query(key: nil) - } else { - // Else field - return .field(key: nil) - } - } + modifiers + .compactMap { $0.builderStatement() } - fileprivate var explicitAPIAttribute: APIAttribute? { - switch type.as(IdentifierTypeSyntax.self)?.name.text { - case "Path": - return .path(key: nil) - case "Body": - return .body - case "Header": - return .header(key: nil) - case "Field": - return .field(key: nil) - case "Query": - return .query(key: nil) - default: - return nil - } - } + // 2. add parameters - private var isClosure: Bool { - type.as(AttributedTypeSyntax.self)?.baseType.is(FunctionTypeSyntax.self) ?? false - } -} + try parameters + .map { EndpointParameter($0, httpMethod: method, pathParameters: pathParameters) } + .validated() + .map { $0.builderStatement() } -/// Represents the mapping between your type's property names and -/// their corresponding request field key. -private enum KeyMapping { - /// Convert property names from camelCase to snake_case for field keys. - /// - /// e.g. `someGreatString` -> `some_great_string` - case snakeCase + // 3. handle the response and return - /// Encode String from camelCase to this KeyMapping strategy. - func encode(_ string: String) -> String { - switch self { - case .snakeCase: - return string.camelCaseToSnakeCase() + try responseStatement() } } -} -extension String { - /// Map camelCase to snake_case. Assumes `self` is already in - /// camelCase. Copied from `Foundation`. - /// - /// - Returns: The snake_cased version of `self`. - fileprivate func camelCaseToSnakeCase() -> String { - guard !self.isEmpty else { return self } - - var words : [Range] = [] - // The general idea of this algorithm is to split words on transition from lower to upper case, then on transition of >1 upper case characters to lowercase - // - // myProperty -> my_property - // myURLProperty -> my_url_property - // - // We assume, per Swift naming conventions, that the first character of the key is lowercase. - var wordStart = self.startIndex - var searchRange = self.index(after: wordStart).. String { + let requestAndValidate = """ + let res = try await provider.request(&req) + try res.validate() + """ + switch returnType { + case .type("Void"), .none: + return "try await provider.request(&req).validate()" + case .type where returnResponseOnly: + return "return try await provider.request(&req)" + case .type(let type): + return """ + \(requestAndValidate) + return try res.decode(\(type).self, using: req.responseDecoder) + """ + case .tuple(let types): + let values = types.map { label, type in + let label = label.map { "\($0): " } ?? "" + if type == "Response" { + return "\(label)res" + } else { + return "\(label)try res.decode(\(type).self, using: req.responseDecoder)" + } } - // Is the next lowercase letter more than 1 after the uppercase? If so, we encountered a group of uppercase letters that we should treat as its own word - let nextCharacterAfterCapital = self.index(after: upperCaseRange.lowerBound) - if lowerCaseRange.lowerBound == nextCharacterAfterCapital { - // The next character after capital is a lower case character and therefore not a word boundary. - // Continue searching for the next upper case for the boundary. - wordStart = upperCaseRange.lowerBound - } else { - // There was a range of >1 capital letters. Turn those into a word, stopping at the capital before the lower case character. - let beforeLowerIndex = self.index(before: lowerCaseRange.lowerBound) - words.append(upperCaseRange.lowerBound.. [DeclSyntax] { -// let messageID = MessageID(domain: "test", id: "papyrus") -// let message = MyDiagnostic(message: "Testing Peer!", diagnosticID: messageID, severity: .warning) -// let diagnostic = Diagnostic(node: Syntax(node), message: message) -// context.diagnose(diagnostic) - // TODO: Add some compiler safety to ensure certain attributes can't be on certain members. + static func expansion( + of node: AttributeSyntax, + providingPeersOf declaration: some DeclSyntaxProtocol, + in context: some MacroExpansionContext + ) throws -> [DeclSyntax] { return [] } } - -struct MyDiagnostic: DiagnosticMessage { - let message: String - let diagnosticID: MessageID - let severity: DiagnosticSeverity -} diff --git a/PapyrusPlugin/Sources/APIAttribute.swift b/PapyrusPlugin/Sources/EndpointModifier.swift similarity index 57% rename from PapyrusPlugin/Sources/APIAttribute.swift rename to PapyrusPlugin/Sources/EndpointModifier.swift index dd413d6..3d6f3b2 100644 --- a/PapyrusPlugin/Sources/APIAttribute.swift +++ b/PapyrusPlugin/Sources/EndpointModifier.swift @@ -1,6 +1,8 @@ import SwiftSyntax -enum APIAttribute { +/// To be parsed from protocol and function attributes. Modifies requests / +/// responses in some way. +enum EndpointModifier { /// Type or Function attributes case json(encoder: String, decoder: String) case urlForm(encoder: String) @@ -13,14 +15,7 @@ enum APIAttribute { /// Function attributes case http(method: String, path: String) - /// Parameter attributes - case body - case field(key: String?) - case query(key: String?) - case header(key: String?) - case path(key: String?) - - init?(syntax: AttributeSyntax) { + init?(_ syntax: AttributeSyntax) { var firstArgument: String? var secondArgument: String? var labeledArguments: [String: String] = [:] @@ -49,16 +44,6 @@ enum APIAttribute { } self = .http(method: secondArgument.withoutQuotes, path: firstArgument.withoutQuotes) - case "Body": - self = .body - case "Field": - self = .field(key: firstArgument?.withoutQuotes) - case "Query": - self = .query(key: firstArgument?.withoutQuotes) - case "Header": - self = .header(key: firstArgument?.withoutQuotes) - case "Path": - self = .path(key: firstArgument?.withoutQuotes) case "Headers": guard let firstArgument else { return nil @@ -96,84 +81,54 @@ enum APIAttribute { } } - func apiBuilderStatement(input: String? = nil) -> String? { + func builderStatement() -> String? { switch self { - case .body: - guard let input else { - return "Input Required!" - } - - return """ - req.setBody(\(input)) - """ - case let .query(key): - guard let input else { - return "Input Required!" - } - - let mapParameter = key == nil ? "" : ", mapKey: false" - return """ - req.addQuery("\(key ?? input)", value: \(input)\(mapParameter)) - """ - case let .header(key): - guard let input else { - return "Input Required!" - } - - let hasCustomKey = key == nil - let convertParameter = hasCustomKey ? "" : ", convertToHeaderCase: true" - return """ - req.addHeader("\(key ?? input)", value: \(input)\(convertParameter)) - """ - case let .path(key): - guard let input else { - return "Input Required!" - } - - return """ - req.addParameter("\(key ?? input)", value: \(input)) - """ - case let .field(key): - guard let input else { - return "Input Required!" - } - - let mapParameter = key == nil ? "" : ", mapKey: false" - return """ - req.addField("\(key ?? input)", value: \(input)\(mapParameter)) - """ case .json(let encoder, let decoder): - return """ + """ req.requestEncoder = .json(\(encoder)) req.responseDecoder = .json(\(decoder)) """ case .urlForm(let encoder): - return """ - req.requestEncoder = .urlForm(\(encoder)) - """ + "req.requestEncoder = .urlForm(\(encoder))" case .multipart(let encoder): - return """ - req.requestEncoder = .multipart(\(encoder)) - """ + "req.requestEncoder = .multipart(\(encoder))" case .converter(let encoder, let decoder): - return """ + """ req.requestEncoder = \(encoder) req.responseDecoder = \(decoder) """ case .headers(let value): - return """ - req.addHeaders(\(value)) - """ + "req.addHeaders(\(value))" case .keyMapping(let value): - return """ - req.keyMapping = \(value) - """ + "req.keyMapping = \(value)" case .authorization(value: let value): - return """ - req.addAuthorization(\(value)) - """ + "req.addAuthorization(\(value))" case .http: - return nil + nil } } } + +extension [EndpointModifier] { + func parseMethodAndPath() throws -> (method: String, path: String, parameters: [String]) { + guard let (method, path) = compactMap({ + if case let .http(method, path) = $0 { return (method, path) } + else { return nil } + }).first else { + throw PapyrusPluginError("No method or path!") + } + + let parameters = path.components(separatedBy: "/") + .compactMap { component in + if component.hasPrefix(":") { + return String(component.dropFirst()) + } else if component.hasPrefix("{") && component.hasSuffix("}") { + return String(component.dropFirst().dropLast()) + } else { + return nil + } + } + + return (method, path, parameters) + } +} diff --git a/PapyrusPlugin/Sources/EndpointParameter.swift b/PapyrusPlugin/Sources/EndpointParameter.swift new file mode 100644 index 0000000..21bec81 --- /dev/null +++ b/PapyrusPlugin/Sources/EndpointParameter.swift @@ -0,0 +1,75 @@ +import SwiftSyntax + +/// To be parsed from function parameters; indicates parts of the request. +enum EndpointParameter { + case body(name: String) + case field(name: String) + case query(name: String) + case header(name: String) + case path(name: String) + + init(_ syntax: FunctionParameterSyntax, httpMethod: String, pathParameters: [String]) { + let typeName = syntax.type.as(IdentifierTypeSyntax.self)?.name.text + let explicitAPIAttribute: EndpointParameter? = switch typeName { + case "Path": .path(name: syntax.name) + case "Body": .body(name: syntax.name) + case "Header": .header(name: syntax.name) + case "Field": .field(name: syntax.name) + case "Query": .query(name: syntax.name) + default: nil + } + + if let explicitAPIAttribute { + // If the attribute is specified, use that. + self = explicitAPIAttribute + } else if pathParameters.contains(syntax.name) { + // If matches a path param, roll with that + self = .path(name: syntax.name) + } else if ["GET", "HEAD", "DELETE"].contains(httpMethod) { + // If method is GET, HEAD, DELETE + self = .query(name: syntax.name) + } else { + // Else field + self = .field(name: syntax.name) + } + } + + func builderStatement() -> String { + switch self { + case .body(let name): + "req.setBody(\(name))" + case .query(let name): + "req.addQuery(\(name.inQuotes), value: \(name))" + case .header(let name): + "req.addHeader(\(name.inQuotes), value: \(name), convertToHeaderCase: true)" + case .path(let name): + "req.addParameter(\(name.inQuotes), value: \(name))" + case .field(let name): + "req.addField(\(name.inQuotes), value: \(name))" + } + } +} + +extension [EndpointParameter] { + func validated() throws -> [EndpointParameter] { + let bodies = filter { + if case .body = $0 { return true } + else { return false } + } + + let fields = filter { + if case .field = $0 { return true } + else { return false } + } + + guard fields.count == 0 || bodies.count == 0 else { + throw PapyrusPluginError("Can't have @Body and @Field!") + } + + guard bodies.count <= 1 else { + throw PapyrusPluginError("Can only have one @Body!") + } + + return self + } +} diff --git a/PapyrusPlugin/Sources/MockMacro.swift b/PapyrusPlugin/Sources/MockMacro.swift index db34744..7261d32 100644 --- a/PapyrusPlugin/Sources/MockMacro.swift +++ b/PapyrusPlugin/Sources/MockMacro.swift @@ -5,21 +5,21 @@ public struct MockMacro: PeerMacro { public static func expansion(of node: AttributeSyntax, providingPeersOf declaration: some DeclSyntaxProtocol, in context: some MacroExpansionContext) throws -> [DeclSyntax] { - try handleError { - guard let type = declaration.as(ProtocolDeclSyntax.self) else { - throw PapyrusPluginError("@Mock can only be applied to protocols.") - } - - let name = node.firstArgument ?? "\(type.typeName)\(node.attributeName)" - return try type.createMock(named: name) + guard let proto = declaration.as(ProtocolDeclSyntax.self) else { + throw PapyrusPluginError("@Mock can only be applied to protocols.") } + + let mock = try proto.createMock(named: "\(proto.protocolName)\(node.attributeName)") + return [ + DeclSyntax(stringLiteral: mock) + ] } } extension ProtocolDeclSyntax { fileprivate func createMock(named mockName: String) throws -> String { """ - \(access)final class \(mockName): \(typeName), @unchecked Sendable { + \(access)final class \(mockName): \(protocolName), @unchecked Sendable { private let notMockedError: Error private var mocks: [String: Any] @@ -43,32 +43,13 @@ extension ProtocolDeclSyntax { extension FunctionDeclSyntax { fileprivate func mockImplementation() throws -> String { - try validateSignature() - - let notFoundExpression: String - switch style { - case .concurrency: - notFoundExpression = "throw notMockedError" - case .completionHandler: - guard let callbackName else { - throw PapyrusPluginError("Missing @escaping completion handler as final function argument.") - } - - let unimplementedError = returnResponseOnly ? ".error(notMockedError)" : ".failure(notMockedError)" - notFoundExpression = """ - \(callbackName)(\(unimplementedError)) - return - """ + guard effects == ["async", "throws"] else { + throw PapyrusPluginError("Function signature must have `async throws`.") } - let mockerArguments = parameters.map(\.variableName).joined(separator: ", ") - let matchExpression: String = - switch style { - case .concurrency: - "return try await mocker(\(mockerArguments))" - case .completionHandler: - "mocker(\(mockerArguments))" - } + let notFoundExpression = "throw notMockedError" + let mockerArguments = parameters.map(\.name).joined(separator: ", ") + let matchExpression = "return try await mocker(\(mockerArguments))" return """ func \(functionName)\(signature) { diff --git a/PapyrusPlugin/Sources/Utilities/AttributeSyntax+Utilities.swift b/PapyrusPlugin/Sources/Utilities/AttributeSyntax+Utilities.swift deleted file mode 100644 index 073aa58..0000000 --- a/PapyrusPlugin/Sources/Utilities/AttributeSyntax+Utilities.swift +++ /dev/null @@ -1,11 +0,0 @@ -import SwiftSyntax - -extension AttributeSyntax { - var firstArgument: String? { - if case let .argumentList(list) = arguments { - return list.first?.expression.description.withoutQuotes - } - - return nil - } -} diff --git a/PapyrusPlugin/Sources/Utilities/Declaration.swift b/PapyrusPlugin/Sources/Utilities/Declaration.swift new file mode 100644 index 0000000..77d55f6 --- /dev/null +++ b/PapyrusPlugin/Sources/Utilities/Declaration.swift @@ -0,0 +1,219 @@ +import SwiftSyntax + +struct Declaration: ExpressibleByStringLiteral { + let text: String + /// Declarations inside a closure following `text`. + let nested: [Declaration]? + + init(stringLiteral value: String) { + self.init(value, nested: nil) + } + + init(_ text: String, nested: [Declaration]? = nil) { + self.text = text + self.nested = nested + } + + init(_ text: String, @DeclarationsBuilder nested: () throws -> [Declaration]) rethrows { + self.text = text + self.nested = try nested() + } + + func formattedString() -> String { + guard let nested else { + return text + } + + let nestedOrdered = isType ? nested.organized() : nested + let nestedFormatted = nestedOrdered + .map { declaration in + declaration + .formattedString() + .replacingOccurrences(of: "\n", with: "\n\t") + } + + let nestedText = nestedFormatted.joined(separator: "\n\t") + return """ + \(text) { + \t\(nestedText) + } + """ + // Using \t screws up macro syntax highlighting + .replacingOccurrences(of: "\t", with: " ") + } + + func declSyntax() -> DeclSyntax { + DeclSyntax(stringLiteral: formattedString()) + } +} + +extension [Declaration] { + /// Reorders declarations in the following manner: + /// + /// 1. Properties (public -> private) + /// 2. initializers (public -> private) + /// 3. functions (public -> private) + /// + /// Properties have no newlines between them, functions have a single, blank + /// newline between them. + fileprivate func organized() -> [Declaration] { + self + .sorted() + .spaced() + } + + private func sorted() -> [Declaration] { + sorted { $0.sortValue < $1.sortValue } + } + + private func spaced() -> [Declaration] { + var declarations: [Declaration] = [] + for declaration in self { + defer { declarations.append(declaration) } + + guard let last = declarations.last else { + continue + } + + if last.isType { + declarations.append(.newline) + } else if last.isProperty && !declaration.isProperty { + declarations.append(.newline) + } else if last.isFunction || last.isInit { + declarations.append(.newline) + } + } + + return declarations + } +} + +extension Declaration { + fileprivate var sortValue: Int { + if isType { + 0 + accessSortValue + } else if isProperty { + 10 + accessSortValue + } else if isInit { + 20 + accessSortValue + } else if !isStaticFunction { + 40 + accessSortValue + } else { + 50 + accessSortValue + } + } + + var accessSortValue: Int { + if text.contains("open") { + 0 + } else if text.contains("public") { + 1 + } else if text.contains("package") { + 2 + } else if text.contains("fileprivate") { + 4 + } else if text.contains("private") { + 5 + } else { + 3 // internal (either explicit or implicit) + } + } + + fileprivate var isType: Bool { + text.contains("enum") || + text.contains("struct") || + text.contains("protocol") || + text.contains("actor") || + text.contains("class") || + text.contains("typealias") + } + + fileprivate var isProperty: Bool { + text.contains("let") || text.contains("var") + } + + fileprivate var isStaticFunction: Bool { + (text.contains("static") || text.contains("class")) && isFunction + } + + fileprivate var isFunction: Bool { + text.contains("func") && text.contains("(") && text.contains(")") + } + + fileprivate var isInit: Bool { + text.contains("init(") + } +} + +extension Declaration { + static let newline: Declaration = "" +} + +@resultBuilder +struct DeclarationsBuilder { + protocol Block { + var declarations: [Declaration] { get } + } + + static func buildBlock(_ components: Block...) -> [Declaration] { + components.flatMap(\.declarations) + } + + // MARK: Declaration literals + + static func buildExpression(_ expression: Declaration) -> Declaration { + expression + } + + static func buildExpression(_ expression: [Declaration]) -> [Declaration] { + expression + } + + static func buildExpression(_ expression: [Declaration]?) -> [Declaration] { + expression ?? [] + } + + // MARK: `String` literals + + static func buildExpression(_ expression: String) -> Declaration { + Declaration(expression) + } + + static func buildExpression(_ expression: [String]) -> [Declaration] { + expression.map { Declaration($0) } + } + + // MARK: `for` + + static func buildArray(_ components: [Declaration]) -> [Declaration] { + components + } + + static func buildArray(_ components: [[Declaration]]) -> [Declaration] { + components.flatMap { $0 } + } + + // MARK: `if` + + static func buildEither(first components: [Declaration]) -> [Declaration] { + components + } + + static func buildEither(second components: [Declaration]) -> [Declaration] { + components + } + + // MARK: `Optional` + + static func buildOptional(_ component: [Declaration]?) -> [Declaration] { + component ?? [] + } +} + +extension Declaration: DeclarationsBuilder.Block { + var declarations: [Declaration] { [self] } +} + +extension [Declaration]: DeclarationsBuilder.Block { + var declarations: [Declaration] { self } +} diff --git a/PapyrusPlugin/Sources/Utilities/FunctionDeclSyntax+Utilities.swift b/PapyrusPlugin/Sources/Utilities/FunctionDeclSyntax+Utilities.swift deleted file mode 100644 index 3b858e6..0000000 --- a/PapyrusPlugin/Sources/Utilities/FunctionDeclSyntax+Utilities.swift +++ /dev/null @@ -1,113 +0,0 @@ -import Foundation -import SwiftSyntax - -extension FunctionDeclSyntax { - enum ReturnType: Equatable { - struct TupleParameter: Equatable { - let label: String? - let type: String - } - - case tuple([TupleParameter]) - case type(String) - } - - enum AsyncStyle { - case concurrency - case completionHandler - } - - // MARK: Async Style - - var style: AsyncStyle { - hasEscapingCompletion ? .completionHandler : .concurrency - } - - func validateSignature() throws { - let hasAsyncAwait = effects.contains("async") && effects.contains("throws") - guard hasEscapingCompletion || hasAsyncAwait else { - throw PapyrusPluginError("Function must either have `async throws` effects or an `@escaping` completion handler as the final argument.") - } - } - - private var hasEscapingCompletion: Bool { - guard let parameter = parameters.last, returnType == nil else { - return false - } - - let type = parameter.type.trimmedDescription - let isResult = type.hasPrefix("@escaping (Result<") && type.hasSuffix("Error>) -> Void") - let isResponse = type == "@escaping (Response) -> Void" - return isResult || isResponse - } - - // MARK: Function effects & attributes - - var functionName: String { - name.text - } - - var effects: [String] { - [signature.effectSpecifiers?.asyncSpecifier, signature.effectSpecifiers?.throwsSpecifier] - .compactMap { $0 } - .map { $0.text } - } - - var parameters: [FunctionParameterSyntax] { - signature - .parameterClause - .parameters - .compactMap { FunctionParameterSyntax($0) } - } - - // MARK: Parameter Information - - var callbackName: String? { - guard let parameter = parameters.last, style == .completionHandler else { - return nil - } - - return parameter.variableName - } - - private var callbackType: String? { - guard let parameter = parameters.last, returnType == nil else { - return nil - } - - let type = parameter.type.trimmedDescription - if type == "@escaping (Response) -> Void" { - return "Response" - } else { - return type - .replacingOccurrences(of: "@escaping (Result<", with: "") - .replacingOccurrences(of: ", Error>) -> Void", with: "") - } - } - - // MARK: Return Data - - var returnResponseOnly: Bool { - responseType == .type("Response") - } - - var responseType: ReturnType? { - if style == .completionHandler, let callbackType { - return .type(callbackType) - } - - return returnType - } - - private var returnType: ReturnType? { - guard let type = signature.returnClause?.type else { - return nil - } - - if let type = type.as(TupleTypeSyntax.self) { - return .tuple(type.elements.map { .init(label: $0.firstName?.text, type: $0.type.trimmedDescription) }) - } else { - return .type(type.trimmedDescription) - } - } -} diff --git a/PapyrusPlugin/Sources/Utilities/FunctionParameterSyntax+Utilities.swift b/PapyrusPlugin/Sources/Utilities/FunctionParameterSyntax+Utilities.swift deleted file mode 100644 index e530c38..0000000 --- a/PapyrusPlugin/Sources/Utilities/FunctionParameterSyntax+Utilities.swift +++ /dev/null @@ -1,7 +0,0 @@ -import SwiftSyntax - -extension FunctionParameterSyntax { - var variableName: String { - (secondName ?? firstName).text - } -} diff --git a/PapyrusPlugin/Sources/Utilities/Macro+Utilities.swift b/PapyrusPlugin/Sources/Utilities/Macro+Utilities.swift deleted file mode 100644 index 20b1e24..0000000 --- a/PapyrusPlugin/Sources/Utilities/Macro+Utilities.swift +++ /dev/null @@ -1,8 +0,0 @@ -import SwiftSyntax -import SwiftSyntaxMacros - -extension Macro { - static func handleError(_ closure: () throws -> String) throws -> [DeclSyntax] { - [DeclSyntax(stringLiteral: try closure())] - } -} diff --git a/PapyrusPlugin/Sources/Utilities/ProtocolDeclSyntax+Utilities.swift b/PapyrusPlugin/Sources/Utilities/ProtocolDeclSyntax+Utilities.swift deleted file mode 100644 index 657fc32..0000000 --- a/PapyrusPlugin/Sources/Utilities/ProtocolDeclSyntax+Utilities.swift +++ /dev/null @@ -1,16 +0,0 @@ -import SwiftSyntax - -extension ProtocolDeclSyntax { - var typeName: String { - name.text - } - - var access: String { - modifiers.first.map { "\($0.trimmedDescription) " } ?? "" - } - - var functions: [FunctionDeclSyntax] { - memberBlock.members - .compactMap { $0.decl.as(FunctionDeclSyntax.self) } - } -} diff --git a/PapyrusPlugin/Sources/Utilities/String+Utilities.swift b/PapyrusPlugin/Sources/Utilities/String+Utilities.swift index ccd86cc..cd00c37 100644 --- a/PapyrusPlugin/Sources/Utilities/String+Utilities.swift +++ b/PapyrusPlugin/Sources/Utilities/String+Utilities.swift @@ -4,4 +4,10 @@ extension String { var withoutQuotes: String { filter { $0 != "\"" } } + + var inQuotes: String { + """ + "\(self)" + """ + } } diff --git a/PapyrusPlugin/Sources/Utilities/SwiftSyntax+Utilities.swift b/PapyrusPlugin/Sources/Utilities/SwiftSyntax+Utilities.swift new file mode 100644 index 0000000..8a68fa4 --- /dev/null +++ b/PapyrusPlugin/Sources/Utilities/SwiftSyntax+Utilities.swift @@ -0,0 +1,92 @@ +import SwiftSyntax + +extension ProtocolDeclSyntax { + var protocolName: String { + name.text + } + + var access: String { + modifiers.first.map { "\($0.trimmedDescription) " } ?? "" + } + + var functions: [FunctionDeclSyntax] { + memberBlock + .members + .compactMap { $0.decl.as(FunctionDeclSyntax.self) } + } + + var protocolAttributes: [AttributeSyntax] { + attributes.compactMap { $0.as(AttributeSyntax.self) } + } +} + +extension FunctionDeclSyntax { + enum ReturnType { + case tuple([(label: String?, type: String)]) + case type(String) + } + + // MARK: Function effects & attributes + + var functionName: String { + name.text + } + + var effects: [String] { + [signature.effectSpecifiers?.asyncSpecifier, signature.effectSpecifiers?.throwsSpecifier] + .compactMap { $0 } + .map { $0.text } + } + + var parameters: [FunctionParameterSyntax] { + signature + .parameterClause + .parameters + .compactMap { FunctionParameterSyntax($0) } + } + + var functionAttributes: [AttributeSyntax] { + attributes.compactMap { $0.as(AttributeSyntax.self) } + } + + // MARK: Return Data + + var returnResponseOnly: Bool { + if case .type("Response") = returnType { + return true + } else { + return false + } + } + + var returnType: ReturnType? { + guard let type = signature.returnClause?.type else { + return nil + } + + if let type = type.as(TupleTypeSyntax.self) { + return .tuple( + type.elements + .map { (label: $0.firstName?.text, type: $0.type.trimmedDescription) } + ) + } else { + return .type(type.trimmedDescription) + } + } +} + +extension FunctionParameterSyntax { + var name: String { + (secondName ?? firstName).text + } +} + +extension AttributeSyntax { + var firstArgument: String? { + if case let .argumentList(list) = arguments { + return list.first?.expression.description.withoutQuotes + } + + return nil + } +} diff --git a/PapyrusPlugin/Tests/APIMacroTests.swift b/PapyrusPlugin/Tests/APIMacroTests.swift index 38afcdf..2b23213 100644 --- a/PapyrusPlugin/Tests/APIMacroTests.swift +++ b/PapyrusPlugin/Tests/APIMacroTests.swift @@ -383,14 +383,14 @@ final class APIMacroTests: XCTestCase { """ @API protocol MyService { - @GET("users/:foo/:b_ar/{baz}/{z_ip}") + @GET("users/:foo/:bAr/{baz}/{zIp}") func getUser(foo: String, bAr: String, baz: Int, zIp: Int) async throws } """ } expansion: { """ protocol MyService { - @GET("users/:foo/:b_ar/{baz}/{z_ip}") + @GET("users/:foo/:bAr/{baz}/{zIp}") func getUser(foo: String, bAr: String, baz: Int, zIp: Int) async throws } @@ -402,11 +402,11 @@ final class APIMacroTests: XCTestCase { } func getUser(foo: String, bAr: String, baz: Int, zIp: Int) async throws { - var req = builder(method: "GET", path: "users/:foo/:b_ar/{baz}/{z_ip}") + var req = builder(method: "GET", path: "users/:foo/:bAr/{baz}/{zIp}") req.addParameter("foo", value: foo) - req.addParameter("b_ar", value: bAr) + req.addParameter("bAr", value: bAr) req.addParameter("baz", value: baz) - req.addParameter("z_ip", value: zIp) + req.addParameter("zIp", value: zIp) try await provider.request(&req).validate() } @@ -417,4 +417,42 @@ final class APIMacroTests: XCTestCase { """ } } + + func testSameAccess() { + assertMacro(["API": APIMacro.self]) { + """ + @API + public protocol MyService { + @GET("name") + func getName() async throws -> String + } + """ + } expansion: { + """ + public protocol MyService { + @GET("name") + func getName() async throws -> String + } + + public struct MyServiceAPI: MyService { + private let provider: PapyrusCore.Provider + + init(provider: PapyrusCore.Provider) { + self.provider = provider + } + + public func getName() async throws -> String { + var req = builder(method: "GET", path: "name") + let res = try await provider.request(&req) + try res.validate() + return try res.decode(String.self, using: req.responseDecoder) + } + + private func builder(method: String, path: String) -> RequestBuilder { + provider.newBuilder(method: method, path: path) + } + } + """ + } + } }