From d716a8e47f9ae6007aa73b804c6bc379dcf04dbb Mon Sep 17 00:00:00 2001 From: Timo <38291523+lovetodream@users.noreply.github.com> Date: Mon, 22 Sep 2025 10:16:20 +0200 Subject: [PATCH 1/2] add support for handling array types --- Package.resolved | 35 ++--- Package.swift | 4 +- .../StatementMacro.swift | 144 ++++++++---------- .../StatementMacroTests.swift | 98 +++++++++--- .../StatementTests.swift | 24 +++ 5 files changed, 178 insertions(+), 127 deletions(-) diff --git a/Package.resolved b/Package.resolved index 8ae19c9..b4d2946 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,12 +1,13 @@ { + "originHash" : "788f0ba284872cb21090a8816842241d12505f9e4a8fa2f38c60c7af74e825bb", "pins" : [ { "identity" : "postgres-nio", "kind" : "remoteSourceControl", "location" : "https://github.com/vapor/postgres-nio.git", "state" : { - "revision" : "ccb25dcc428587224633a79c0ce0430eeac3dc0f", - "version" : "1.26.2" + "revision" : "8ee6118c03501196be183b0938d2ec4478c18954", + "version" : "1.27.0" } }, { @@ -50,8 +51,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-crypto.git", "state" : { - "revision" : "176abc28e002a9952470f08745cd26fad9286776", - "version" : "3.13.3" + "revision" : "d1c6b70f7c5f19fb0b8750cb8dcdf2ea6e2d8c34", + "version" : "3.15.0" } }, { @@ -68,8 +69,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-metrics.git", "state" : { - "revision" : "4c83e1cdf4ba538ef6e43a9bbd0bcc33a0ca46e3", - "version" : "2.7.0" + "revision" : "0743a9364382629da3bf5677b46a2c4b1ce5d2a6", + "version" : "2.7.1" } }, { @@ -77,8 +78,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio.git", "state" : { - "revision" : "a5fea865badcb1c993c85b0f0e8d05a4bd2270fb", - "version" : "2.85.0" + "revision" : "1c30f0f2053b654e3d1302492124aa6d242cdba7", + "version" : "2.86.0" } }, { @@ -86,8 +87,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-ssl.git", "state" : { - "revision" : "385f5bd783ffbfff46b246a7db7be8e4f04c53bd", - "version" : "2.33.0" + "revision" : "737e550e607d82bf15bdfddf158ec61652ce836f", + "version" : "2.34.0" } }, { @@ -95,8 +96,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-transport-services.git", "state" : { - "revision" : "decfd235996bc163b44e10b8a24997a3d2104b90", - "version" : "1.25.0" + "revision" : "e645014baea2ec1c2db564410c51a656cf47c923", + "version" : "1.25.1" } }, { @@ -113,8 +114,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/swiftlang/swift-syntax.git", "state" : { - "revision" : "f99ae8aa18f0cf0d53481901f88a0991dc3bd4a2", - "version" : "601.0.1" + "revision" : "4799286537280063c85a32f09884cfbca301b1a1", + "version" : "602.0.0" } }, { @@ -122,10 +123,10 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-system.git", "state" : { - "revision" : "b63d24d465e237966c3f59f47dcac6c70fb0bca3", - "version" : "1.6.1" + "revision" : "395a77f0aa927f0ff73941d7ac35f2b46d47c9db", + "version" : "1.6.3" } } ], - "version" : 2 + "version" : 3 } diff --git a/Package.swift b/Package.swift index e7da62a..956b8ef 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version: 5.9 +// swift-tools-version: 6.0 import PackageDescription import CompilerPluginSupport @@ -12,7 +12,7 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/swiftlang/swift-syntax.git", "509.0.0-latest"..."601.0.1-latest"), + .package(url: "https://github.com/swiftlang/swift-syntax.git", "600.0.0-latest"..."602.0.0-latest"), .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.0.0"), ], targets: [ diff --git a/Sources/PostgresNIOMacrosPlugin/StatementMacro.swift b/Sources/PostgresNIOMacrosPlugin/StatementMacro.swift index b0156de..791fa45 100644 --- a/Sources/PostgresNIOMacrosPlugin/StatementMacro.swift +++ b/Sources/PostgresNIOMacrosPlugin/StatementMacro.swift @@ -59,8 +59,8 @@ private enum StatementMacroError: Error, DiagnosticMessage { } public struct StatementMacro: ExtensionMacro, MemberMacro { - private typealias Column = (name: String, type: TokenSyntax, isOptional: Bool, alias: String?) - private typealias Bind = (name: String, type: TokenSyntax, isOptional: Bool) + private typealias Column = (name: String, type: TypeAnnotationSyntax, expression: LabeledExprSyntax, alias: String?) + private typealias Bind = (name: String, type: TypeAnnotationSyntax, isOptional: Bool) public static func expansion( of node: AttributeSyntax, @@ -72,16 +72,7 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { guard declaration.is(StructDeclSyntax.self) else { return [] } - #if canImport(SwiftSyntax600) let protocols = protocols.map { InheritedTypeSyntax(type: $0) } - #else - let protocols = if protocols.isEmpty { - // In tests, the protocol is not added before 600, so we'll add it manually - [InheritedTypeSyntax(type: TypeSyntax(stringLiteral: "PostgresPreparedStatement"))] - } else { - protocols.map { InheritedTypeSyntax(type: $0) } - } - #endif return [ ExtensionDeclSyntax( extendedType: type, @@ -102,7 +93,6 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { public static func expansion(of node: AttributeSyntax, providingMembersOf declaration: some DeclGroupSyntax, in context: some MacroExpansionContext) throws -> [DeclSyntax] { guard declaration.is(StructDeclSyntax.self) else { - #if canImport(SwiftSyntax600) context.diagnose(Diagnostic( node: node, message: StatementMacroDiagnosticMessages.invalidDeclaration, @@ -113,13 +103,6 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { ) ]) )) - #else - context.diagnose( - Diagnostic( - node: node, - message: StatementMacroDiagnosticMessages.invalidDeclaration - )) - #endif return [] } @@ -182,7 +165,7 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { itemsBuilder: { PatternBindingSyntax( pattern: IdentifierPatternSyntax(identifier: .identifier(name)), - typeAnnotation: makeTypeSyntax(type, optional: isOptional) + typeAnnotation: type ) } ) @@ -214,39 +197,72 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { let identifier = iterator.next()! as LabeledExprSyntax // works as tuple contains at least two elements // Type can be force-unwrapped as the compiler ensures it is there. let rawType = iterator.next()!.expression.as(MemberAccessExprSyntax.self)!.base! - #if canImport(SwiftSyntax600) let label = identifier.label?.identifier?.name - #else - let label = identifier.label?.text - #endif - let type: TokenSyntax - let isOptional: Bool - if let nonOptional = rawType.as(DeclReferenceExprSyntax.self)?.baseName { - type = nonOptional - isOptional = false - } else if let optional = rawType.as(OptionalChainingExprSyntax.self)?.expression.as( - DeclReferenceExprSyntax.self)?.baseName + let isBind = label == "bind" + + let type: TypeAnnotationSyntax + enum Metadata { + case bind(isOptional: Bool) + case column(LabeledExprSyntax) + } + let metadata: Metadata + + if let nonOptionalExpression = rawType.as(DeclReferenceExprSyntax.self) { + type = TypeAnnotationSyntax(type: IdentifierTypeSyntax(name: nonOptionalExpression.baseName)) + if isBind { + metadata = .bind(isOptional: false) + } else { + metadata = .column(LabeledExprSyntax(expression: nonOptionalExpression)) + } + } else if + let optionalExpression = rawType.as(OptionalChainingExprSyntax.self), + let optional = optionalExpression.expression.as(DeclReferenceExprSyntax.self)?.baseName { - type = optional - isOptional = true - } else { + type = TypeAnnotationSyntax(type: OptionalTypeSyntax(wrappedType: IdentifierTypeSyntax(name: optional))) + if isBind { + metadata = .bind(isOptional: true) + } else { + metadata = .column(LabeledExprSyntax(expression: optionalExpression)) + } + } else if + let array = rawType.as(ArrayExprSyntax.self), + let arrayElement = array.elements.first?.expression.as(DeclReferenceExprSyntax.self)?.baseName + { + type = TypeAnnotationSyntax(type: ArrayTypeSyntax(element: IdentifierTypeSyntax(name: arrayElement))) + if isBind { + metadata = .bind(isOptional: false) + } else { + metadata = .column(LabeledExprSyntax(expression: array)) + } + } else if + let optionalArray = rawType.as(OptionalChainingExprSyntax.self), + let optionalArrayExpression = optionalArray.expression.as(ArrayExprSyntax.self), + let arrayElement = optionalArrayExpression.elements.first?.expression.as(DeclReferenceExprSyntax.self)?.baseName + { + type = TypeAnnotationSyntax(type: OptionalTypeSyntax(wrappedType: ArrayTypeSyntax(element: IdentifierTypeSyntax(name: arrayElement)))) + if isBind { + metadata = .bind(isOptional: true) + } else { + metadata = .column(LabeledExprSyntax(expression: optionalArray)) + } + } + else { throw StatementMacroError.unprocessableInterpolation( name: identifier.expression.as(StringLiteralExprSyntax.self)?.segments.first?.as( StringSegmentSyntax.self)?.content.text ?? "", - isBind: label == "bind" + isBind: isBind ) } // Same thing as with type. let name = identifier.expression.as(StringLiteralExprSyntax.self)! .segments.first!.as(StringSegmentSyntax.self)!.content.text - switch label { - case "bind": + switch metadata { + case .bind(let isOptional): return .bind((name: name, type: type, isOptional: isOptional)) - default: + case .column(let expression): let alias = iterator.next()?.expression.as(StringLiteralExprSyntax.self)? .segments.first?.as(StringSegmentSyntax.self)?.content.text - - return .column((name: name, type: type, isOptional: isOptional, alias: alias)) + return .column((name: name, type: type, expression: expression, alias: alias)) } } @@ -255,7 +271,7 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { structKeyword: .keyword(.struct, trailingTrivia: .space), name: .identifier("Row", trailingTrivia: .space), memberBlockBuilder: { - for (name, type, isOptional, alias) in columns { + for (name, type, _, alias) in columns { MemberBlockItemSyntax( decl: VariableDeclSyntax( bindingSpecifier: .keyword(.var, trailingTrivia: .space), @@ -263,7 +279,7 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { itemsBuilder: { PatternBindingSyntax( pattern: IdentifierPatternSyntax(identifier: .identifier(alias ?? name)), - typeAnnotation: makeTypeSyntax(type, optional: isOptional) + typeAnnotation: type ) } ) @@ -379,20 +395,12 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { } private static func bindingsFunctionSignature() -> FunctionSignatureSyntax { - #if canImport(SwiftSyntax600) FunctionSignatureSyntax( parameterClause: .init(parameters: []), effectSpecifiers: FunctionEffectSpecifiersSyntax( throwsClause: ThrowsClauseSyntax(throwsSpecifier: .keyword(.throws))), returnClause: ReturnClauseSyntax(type: TypeSyntax(stringLiteral: "PostgresBindings")) ) - #else - FunctionSignatureSyntax( - parameterClause: .init(parameters: []), - effectSpecifiers: FunctionEffectSpecifiersSyntax(throwsSpecifier: .keyword(.throws)), - returnClause: ReturnClauseSyntax(type: TypeSyntax(stringLiteral: "PostgresBindings")) - ) - #endif } private static func decodeRow(from columns: [Column]) -> FunctionDeclSyntax { @@ -423,8 +431,8 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { argumentsBuilder: { LabeledExprSyntax(expression: MemberAccessExprSyntax( base: TupleExprSyntax(elementsBuilder: { - for (_, column, isOptional, _) in columns { - makeTypeExpressionSyntax(for: column, optional: isOptional) + for (_, _, column, _) in columns { + column } }), declName: DeclReferenceExprSyntax(baseName: .keyword(.self)) @@ -456,7 +464,6 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { } private static func decodeRowFunctionSignature() -> FunctionSignatureSyntax { - #if canImport(SwiftSyntax600) FunctionSignatureSyntax( parameterClause: .init(parameters: [ FunctionParameterSyntax( @@ -470,37 +477,6 @@ public struct StatementMacro: ExtensionMacro, MemberMacro { ), returnClause: ReturnClauseSyntax(type: TypeSyntax(stringLiteral: "Row")) ) - #else - FunctionSignatureSyntax( - parameterClause: .init(parameters: [ - FunctionParameterSyntax( - firstName: .wildcardToken(), - secondName: .identifier("row"), - type: TypeSyntax(stringLiteral: "PostgresRow") - ) - ]), - effectSpecifiers: FunctionEffectSpecifiersSyntax( - throwsSpecifier: .keyword(.throws) - ), - returnClause: ReturnClauseSyntax(type: TypeSyntax(stringLiteral: "Row")) - ) - #endif - } - - private static func makeTypeExpressionSyntax(for type: TokenSyntax, optional: Bool) -> LabeledExprSyntax { - if optional { - LabeledExprSyntax(expression: OptionalChainingExprSyntax(expression: DeclReferenceExprSyntax(baseName: type))) - } else { - LabeledExprSyntax(expression: DeclReferenceExprSyntax(baseName: type)) - } - } - - private static func makeTypeSyntax(_ type: TokenSyntax, optional: Bool) -> TypeAnnotationSyntax { - if optional { - TypeAnnotationSyntax(type: OptionalTypeSyntax(wrappedType: IdentifierTypeSyntax(name: type))) - } else { - TypeAnnotationSyntax(type: IdentifierTypeSyntax(name: type)) - } } } diff --git a/Tests/PostgresNIOMacrosPluginTests/StatementMacroTests.swift b/Tests/PostgresNIOMacrosPluginTests/StatementMacroTests.swift index c6f86f7..5fdd10a 100644 --- a/Tests/PostgresNIOMacrosPluginTests/StatementMacroTests.swift +++ b/Tests/PostgresNIOMacrosPluginTests/StatementMacroTests.swift @@ -9,23 +9,12 @@ import XCTest #if canImport(PostgresNIOMacrosPlugin) import PostgresNIOMacrosPlugin -#if canImport(SwiftSyntax600) let testMacros: [String: MacroSpec] = [ "Statement": MacroSpec(type: StatementMacro.self, conformances: ["PostgresPreparedStatement"]), ] -#else -let testMacros: [String: Macro.Type] = [ - "Statement": StatementMacro.self, -] -#endif #endif final class StatementMacroTests: XCTestCase { - #if canImport(SwiftSyntax600) - let trailingNewline = "\n" - #else - let trailingNewline = "" - #endif func testMacro() throws { #if canImport(PostgresNIOMacrosPlugin) @@ -56,7 +45,8 @@ final class StatementMacroTests: XCTestCase { func decodeRow(_ row: PostgresRow) throws -> Row { let (id, name, age) = try row.decode((UUID, String, Int).self) return Row(id: id, name: name, age: age) - }\(trailingNewline)} + } + } extension MyStatement: PostgresPreparedStatement { } @@ -93,7 +83,8 @@ final class StatementMacroTests: XCTestCase { func decodeRow(_ row: PostgresRow) throws -> Row { let (id, name, age) = try row.decode((UUID, String, Int).self) return Row(id: id, name: name, age: age) - }\(trailingNewline)} + } + } extension MyStatement: PostgresPreparedStatement { } @@ -134,7 +125,8 @@ final class StatementMacroTests: XCTestCase { } func decodeRow(_ row: PostgresRow) throws -> Row { - }\(trailingNewline)} + } + } extension MyStatement: PostgresPreparedStatement { } @@ -175,7 +167,8 @@ final class StatementMacroTests: XCTestCase { func decodeRow(_ row: PostgresRow) throws -> Row { let (userID, name, age) = try row.decode((UUID, String, Int).self) return Row(userID: userID, name: name, age: age) - }\(trailingNewline)} + } + } extension MyStatement: PostgresPreparedStatement { } @@ -206,7 +199,8 @@ final class StatementMacroTests: XCTestCase { } func decodeRow(_ row: PostgresRow) throws -> Row { - }\(trailingNewline)} + } + } extension MyStatement: PostgresPreparedStatement { } @@ -237,7 +231,8 @@ final class StatementMacroTests: XCTestCase { } func decodeRow(_ row: PostgresRow) throws -> Row { - }\(trailingNewline)} + } + } extension MyStatement: PostgresPreparedStatement { } @@ -251,11 +246,7 @@ final class StatementMacroTests: XCTestCase { func testMacroOnClassDoesNotWork() throws { #if canImport(PostgresNIOMacrosPlugin) - #if canImport(SwiftSyntax600) let fixIts = [FixItSpec(message: "Replace 'class' with 'struct'")] - #else - let fixIts: [FixItSpec] = [] - #endif assertMacroExpansion( #"@Statement("") class MyStatement {}"#, expandedSource: "class MyStatement {}", @@ -307,7 +298,8 @@ final class StatementMacroTests: XCTestCase { func decodeRow(_ row: PostgresRow) throws -> Row { let (id, name, age) = try row.decode((UUID, String, Int).self) return Row(id: id, name: name, age: age) - }\(trailingNewline)} + } + } extension MyStatement: PostgresPreparedStatement { } @@ -352,7 +344,8 @@ final class StatementMacroTests: XCTestCase { func decodeRow(_ row: PostgresRow) throws -> Row { let (id, name, age) = try row.decode((UUID?, String, Int).self) return Row(id: id, name: name, age: age) - }\(trailingNewline)} + } + } extension MyStatement: PostgresPreparedStatement { } @@ -428,7 +421,64 @@ final class StatementMacroTests: XCTestCase { func decodeRow(_ row: PostgresRow) throws -> Row { let (id, name, age) = try row.decode((UUID, String, Int).self) return Row(id: id, name: name, age: age) - }\#(trailingNewline)} + } + } + + extension MyStatement: PostgresPreparedStatement { + } + """#, + macroSpecs: testMacros + ) + #else + throw XCTSkip("macros are only supported when running tests for the host platform") + #endif + } + + func testEncodableArrayMacro() throws { + #if canImport(PostgresNIOMacrosPlugin) + assertMacroExpansion( + #""" + @Statement(""" + SELECT \("names", [String].self), \("memberOf", [Int]?.self) + WHERE names = \(bind: "names", [String].self) OR memberOf = \(bind: "memberOf", [Int]?.self) + FROM groups + """) + struct MyStatement {} + """#, + expandedSource: #""" + struct MyStatement { + + struct Row { + var names: [String] + var memberOf: [Int]? + } + + static let sql = """ + SELECT names, memberOf + WHERE names = $1 OR memberOf = $2 + FROM groups + """ + + var names: [String] + + var memberOf: [Int]? + + func makeBindings() throws -> PostgresBindings { + var bindings = PostgresBindings(capacity: 2) + bindings.append(names) + if let memberOf { + bindings.append(memberOf) + } else { + bindings.appendNull() + } + return bindings + } + + func decodeRow(_ row: PostgresRow) throws -> Row { + let (names, memberOf) = try row.decode(([String], [Int]?).self) + return Row(names: names, memberOf: memberOf) + } + } extension MyStatement: PostgresPreparedStatement { } diff --git a/Tests/PostgresNIOMacrosTests/StatementTests.swift b/Tests/PostgresNIOMacrosTests/StatementTests.swift index 744714c..de0ebdf 100644 --- a/Tests/PostgresNIOMacrosTests/StatementTests.swift +++ b/Tests/PostgresNIOMacrosTests/StatementTests.swift @@ -61,6 +61,24 @@ final class StatementTests { #expect(stream4Count == 0) } } + + @Test + func arraySelects() async throws { + do { + try await self.client.withConnection { connection in + let stream1 = try await connection.execute(ArraySelect(exact: ["1"]), logger: logger) + var stream1Count = 0 + for try await row in stream1 { + #expect(row.value == ["1", "2", "3", "4"]) + stream1Count += 1 + } + #expect(stream1Count == 1) + } + } catch { + print(String(reflecting: error)) + throw error + } + } } @Statement("SELECT \("1", Int.self, as: "count")") @@ -75,6 +93,12 @@ private struct SimpleNullSelect {} @Statement("SELECT \("1", Int.self, as: "count") WHERE \(bind: "minCount", Int?.self) != 0") private struct SimpleSelectWithOptionalWhereClause {} +@Statement(""" +SELECT \("string_to_array('1,2,3,4', ',')", [String].self, as: "value") +WHERE \(bind: "exact", [String].self) = string_to_array('1', '') +""") +private struct ArraySelect {} + func env(_ name: String) -> String? { getenv(name).flatMap { String(cString: $0) } } From 78b566d3d41925b07bfc4df411fdc81a939a7b8b Mon Sep 17 00:00:00 2001 From: Timo <38291523+lovetodream@users.noreply.github.com> Date: Mon, 22 Sep 2025 10:29:02 +0200 Subject: [PATCH 2/2] cover all supported swift versions in tests Signed-off-by: Timo <38291523+lovetodream@users.noreply.github.com> --- .github/workflows/test-latest.yml | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test-latest.yml b/.github/workflows/test-latest.yml index 2454379..3be8025 100644 --- a/.github/workflows/test-latest.yml +++ b/.github/workflows/test-latest.yml @@ -8,8 +8,15 @@ on: jobs: test: + strategy: + fail-fast: false + matrix: + swift-image: + - swift:6.0-jammy + - swift:6.1-noble + - swift:6.2-noble container: - image: swift:jammy + image: ${{ matrix.swift-image }} services: postgres: image: postgres @@ -29,10 +36,27 @@ jobs: if: runner.debug == '1' run: | echo "LOG_LEVEL=trace" >> "$GITHUB_ENV" + - name: Install zstd + run: | + apt-get update -y + apt-get install -y zstd + - name: Restore .build + id: "restore-build" + uses: actions/cache/restore@v4 + with: + path: .build + key: "swiftpm-tests-build-${{ runner.os }}-${{ github.event.pull_request.base.sha || github.event.after }}" + restore-keys: "swiftpm-tests-build-${{ runner.os }}-" - name: Build - run: swift build + run: swift build --build-tests --enable-code-coverage + - name: Cache .build + if: steps.restore-build.outputs.cache-hit != 'true' + uses: actions/cache/save@v4 + with: + path: .build + key: "swiftpm-tests-build-${{ runner.os }}-${{ github.event.pull_request.base.sha || github.event.after }}" - name: Run tests - run: swift test --enable-code-coverage + run: swift test --skip-build --enable-code-coverage env: PSQL_HOST: postgres PSQL_PORT: 5432