Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .github/release.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
changelog:
exclude:
labels:
- semver-noop
categories:
- title: SemVer Major
labels:
- semver-major
- title: SemVer Minor
labels:
- semver-minor
- title: SemVer Patch
labels:
- semver-patch
- title: Other Changes
labels:
- "*"
24 changes: 15 additions & 9 deletions Sources/PostgresNIOMacrosPlugin/StatementMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -115,29 +115,29 @@ public struct StatementMacro: ExtensionMacro, MemberMacro {
}

// It is fine to force unwrap here, because the compiler ensures we receive this exact syntax tree here.
let elements = node
let unparsedString = node
.arguments!.as(LabeledExprListSyntax.self)!
.first!.expression.as(StringLiteralExprSyntax.self)!.segments
.first!.expression.as(StringLiteralExprSyntax.self)!

var sql = ""
var sql: StringLiteralSegmentListSyntax = []
var columns: [Column] = []
var binds: [Bind] = []
for element in elements {
for element in unparsedString.segments {
if let expression = element.as(ExpressionSegmentSyntax.self) {
let interpolation = try extractInterpolations(expression)
switch interpolation {
case .column(let column):
columns.append(column)
sql.append(column.name)
sql.append(.init(StringSegmentSyntax(content: .stringSegment(column.name))))
if let alias = column.alias {
sql.append(" AS \(alias)")
sql.append(.init(StringSegmentSyntax(content: .stringSegment(" AS \(alias)"))))
}
case .bind(let bind):
binds.append(bind)
sql.append("$\(binds.count)")
sql.append(.init(StringSegmentSyntax(content: .stringSegment("$\(binds.count)"))))
}
} else if let expression = element.as(StringSegmentSyntax.self) {
sql.append(expression.content.text)
sql.append(.init(expression))
}
}

Expand All @@ -156,7 +156,13 @@ public struct StatementMacro: ExtensionMacro, MemberMacro {
) {
PatternBindingSyntax(
pattern: IdentifierPatternSyntax(identifier: .identifier("sql")),
initializer: InitializerClauseSyntax(value: StringLiteralExprSyntax(content: sql))
initializer: InitializerClauseSyntax(
value: StringLiteralExprSyntax(
openingQuote: unparsedString.openingQuote,
segments: sql,
closingQuote: unparsedString.closingQuote
)
)
)
}

Expand Down
49 changes: 49 additions & 0 deletions Tests/PostgresNIOMacrosPluginTests/StatementMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,55 @@ final class StatementMacroTests: XCTestCase {
throw XCTSkip("macros are only supported when running tests for the host platform")
#endif
}

func testMultilineMacro() throws {
#if canImport(PostgresNIOMacrosPlugin)
assertMacroExpansion(
#"""
@Statement("""
SELECT \("id", UUID.self), \("name", String.self), \("age", Int.self)
FROM users
WHERE \(bind: "age", Int.self) > age
""")
struct MyStatement {}
"""#,
expandedSource: #"""
struct MyStatement {

struct Row {
var id: UUID
var name: String
var age: Int
}

static let sql = """
SELECT id, name, age
FROM users
WHERE $1 > age
"""

var age: Int

func makeBindings() throws -> PostgresBindings {
var bindings = PostgresBindings(capacity: 1)
bindings.append(age)
return bindings
}

func decodeRow(_ row: PostgresRow) throws -> Row {
let (id, name, age) = try row.decode((UUID, String, Int).self)
return Row(id: id, name: name, age: age)
}\#(trailingNewline)}

extension MyStatement: PostgresPreparedStatement {
}
"""#,
macroSpecs: testMacros
)
#else
throw XCTSkip("macros are only supported when running tests for the host platform")
#endif
}
}


Expand Down