Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ jobs:
# We pass the list of examples here, but we can't pass an array as argument
# Instead, we pass a String with a valid JSON array.
# The workaround is mentioned here https://github.com/orgs/community/discussions/11692
examples: "[ 'api-key', 'converse', 'converse-stream', 'embeddings', 'openai', 'text_chat' ]"
examples: "[ 'api-key', 'converse', 'converse-stream', 'embeddings', 'openai', 'retrieve', 'text_chat' ]"

swift-6-language-mode:
name: Swift 6 Language Mode
Expand Down
8 changes: 8 additions & 0 deletions Examples/retrieve/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.DS_Store
/.build
/Packages
xcuserdata/
DerivedData/
.swiftpm/configuration/registries.json
.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
.netrc
30 changes: 30 additions & 0 deletions Examples/retrieve/Package.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// swift-tools-version: 6.0
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
name: "Retrieve",
platforms: [.macOS(.v15), .iOS(.v18), .tvOS(.v18)],
products: [
.executable(name: "Retrieve", targets: ["Retrieve"])
],
dependencies: [
// for production use, uncomment the following line
// .package(url: "https://github.com/build-on-aws/swift-bedrock-library.git", branch: "main"),

// for local development, use the following line
.package(name: "swift-bedrock-library", path: "../.."),

.package(url: "https://github.com/apple/swift-log.git", from: "1.5.0"),
],
targets: [
.executableTarget(
name: "Retrieve",
dependencies: [
.product(name: "BedrockService", package: "swift-bedrock-library"),
.product(name: "Logging", package: "swift-log"),
]
)
]
)
82 changes: 82 additions & 0 deletions Examples/retrieve/Sources/Retrieve.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift Bedrock Library open source project
//
// Copyright (c) 2025 Amazon.com, Inc. or its affiliates
// and the Swift Bedrock Library project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import BedrockService
import Logging

@main
struct Main {
static func main() async throws {
do {
try await Main.retrieve()
} catch {
print("Error:\n\(error)")
}
}

static func retrieve() async throws {
var logger = Logger(label: "Retrieve")
logger.logLevel = .debug

let bedrock = try await BedrockService(
region: .uswest2,
logger: logger
// uncomment if you use SSO with AWS Identity Center
// authentication: .sso
)

let knowledgeBaseId = "EQ13XRVPLE"
let query = "should I write open source or open-source"
let numberOfResults = 3

print("Retrieving from knowledge base...")
print("Knowledge Base ID: \(knowledgeBaseId)")
print("Query: \(query)")
print("Number of results: \(numberOfResults)")
print()

let response = try await bedrock.retrieve(
knowledgeBaseId: knowledgeBaseId,
retrievalQuery: query,
numberOfResults: numberOfResults
)

print("Retrieved \(response.results?.count ?? 0) results:")

// Show best match using convenience function
if let bestMatch = response.bestMatch() {
print("\n--- Best Match (Score: \(bestMatch.score ?? 0)) ---")
if let content = bestMatch.content?.text {
print("Content: \(content)")
}
}

// Show all results using convenience property
// if let results = response.results {
// for (index, result) in results.enumerated() {
// print("\n--- Result \(index + 1) ---")
// if let content = result.content?.text {
// print("Content: \(content)")
// }
// if let score = result.score {
// print("Score: \(score)")
// }
// if let location = result.location?.s3Location {
// print("Source: s3://\(location.uri ?? "unknown")")
// }
// }
// }
}
}
1 change: 1 addition & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ let package = Package(
.product(name: "AWSClientRuntime", package: "aws-sdk-swift"),
.product(name: "AWSBedrock", package: "aws-sdk-swift"),
.product(name: "AWSBedrockRuntime", package: "aws-sdk-swift"),
.product(name: "AWSBedrockAgentRuntime", package: "aws-sdk-swift"),
.product(name: "AWSSSOOIDC", package: "aws-sdk-swift"),
.product(name: "Smithy", package: "smithy-swift"),
.product(name: "Logging", package: "swift-log"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
//===----------------------------------------------------------------------===//

import AWSBedrock
import AWSBedrockAgentRuntime
import AWSBedrockRuntime
import ClientRuntime
import SmithyHTTPAuthAPI
Expand All @@ -38,3 +39,6 @@ extension BedrockClient.BedrockClientConfiguration: @retroactive @unchecked Send
extension BedrockRuntimeClient.BedrockRuntimeClientConfiguration: @retroactive @unchecked Sendable,
BedrockConfigProtocol
{}
extension BedrockAgentRuntimeClient.BedrockAgentRuntimeClientConfiguration: @retroactive @unchecked Sendable,
BedrockConfigProtocol
{}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift Bedrock Library open source project
//
// Copyright (c) 2025 Amazon.com, Inc. or its affiliates
// and the Swift Bedrock Library project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

@preconcurrency import AWSBedrockAgentRuntime
import AWSClientRuntime

#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif

/// Protocol for Amazon Bedrock Runtime Agent operations
///
/// This protocol allows writing mocks for unit tests and provides a clean interface
/// for knowledge base retrieval operations.
public protocol BedrockRuntimeAgentProtocol: Sendable {
/// Retrieves information from a knowledge base
/// - Parameter input: The retrieve input containing query and configuration
/// - Returns: RetrieveOutput with the retrieved results
/// - Throws: Error if the retrieval operation fails
func retrieve(input: RetrieveInput) async throws -> RetrieveOutput
}

extension BedrockAgentRuntimeClient: @retroactive @unchecked Sendable, BedrockRuntimeAgentProtocol {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift Bedrock Library open source project
//
// Copyright (c) 2025 Amazon.com, Inc. or its affiliates
// and the Swift Bedrock Library project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

@preconcurrency import AWSBedrockAgentRuntime
import Logging

extension BedrockService {
/// Creates a BedrockAgentRuntimeClient
/// - Parameters:
/// - region: The AWS region to configure the client for
/// - authentication: The authentication type to use
/// - logger: Logger instance
/// - Returns: Configured BedrockAgentRuntimeProtocol instance
/// - Throws: Error if client creation fails
internal static func createBedrockAgentRuntimeClient(
region: Region,
authentication: BedrockAuthentication,
logger: Logging.Logger
) async throws -> BedrockAgentRuntimeClient {
let config: BedrockAgentRuntimeClient.BedrockAgentRuntimeClientConfiguration = try await prepareConfig(
initialConfig: BedrockAgentRuntimeClient.BedrockAgentRuntimeClientConfiguration(region: region.rawValue),
authentication: authentication,
logger: logger
)
return BedrockAgentRuntimeClient(config: config)
}
/// Retrieves information from a knowledge base for RAG applications
///
/// This method queries an Amazon Bedrock knowledge base to retrieve relevant information
/// that can be used for Retrieval-Augmented Generation (RAG) applications.
///
/// - Parameters:
/// - knowledgeBaseId: The unique identifier of the knowledge base to query
/// - retrievalQuery: The query to search for in the knowledge base
/// - numberOfResults: The number of results to return (optional, defaults to 3)
/// - Returns: RetrieveResult containing the retrieved results with convenience methods
/// - Throws: BedrockLibraryError or other errors from the underlying service
public func retrieve(
knowledgeBaseId: String,
retrievalQuery: String,
numberOfResults: Int = 3
) async throws -> RetrieveResult {
logger.trace(
"Retrieving from knowledge base",
metadata: [
"knowledgeBaseId": .string(knowledgeBaseId),
"numberOfResults": .stringConvertible(numberOfResults),
]
)

let request = RetrieveRequest(
knowledgeBaseId: knowledgeBaseId,
retrievalQuery: retrievalQuery,
numberOfResults: numberOfResults
)

do {
let response = try await bedrockAgentRuntimeClient.retrieve(input: request.input)
logger.trace("Successfully retrieved from knowledge base")
return RetrieveResult(response)
} catch {
try handleCommonError(error, context: "retrieving from knowledge base")
}
}
}
53 changes: 53 additions & 0 deletions Sources/BedrockService/BedrockRuntimeAgent/RetrieveRequest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift Bedrock Library open source project
//
// Copyright (c) 2025 Amazon.com, Inc. or its affiliates
// and the Swift Bedrock Library project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

@preconcurrency import AWSBedrockAgentRuntime

/// A request for retrieving information from a knowledge base
public struct RetrieveRequest: Sendable {
/// The unique identifier of the knowledge base to query
public let knowledgeBaseId: String
/// The query text to search for in the knowledge base
public let retrievalQuery: String
/// The number of results to return
public let numberOfResults: Int

/// Creates a new retrieve request
/// - Parameters:
/// - knowledgeBaseId: The unique identifier of the knowledge base to query
/// - retrievalQuery: The query text to search for in the knowledge base
/// - numberOfResults: The number of results to return (defaults to 3)
public init(
knowledgeBaseId: String,
retrievalQuery: String,
numberOfResults: Int = 3
) {
self.knowledgeBaseId = knowledgeBaseId
self.retrievalQuery = retrievalQuery
self.numberOfResults = numberOfResults
}

internal var input: RetrieveInput {
RetrieveInput(
knowledgeBaseId: knowledgeBaseId,
retrievalConfiguration: BedrockAgentRuntimeClientTypes.KnowledgeBaseRetrievalConfiguration(
vectorSearchConfiguration: BedrockAgentRuntimeClientTypes.KnowledgeBaseVectorSearchConfiguration(
numberOfResults: numberOfResults
)
),
retrievalQuery: BedrockAgentRuntimeClientTypes.KnowledgeBaseQuery(text: retrievalQuery)
)
}
}
72 changes: 72 additions & 0 deletions Sources/BedrockService/BedrockRuntimeAgent/RetrieveResult.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift Bedrock Library open source project
//
// Copyright (c) 2025 Amazon.com, Inc. or its affiliates
// and the Swift Bedrock Library project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

@preconcurrency import AWSBedrockAgentRuntime

#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif

/// Type alias for knowledge base retrieval result
public typealias RAGRetrievalResult = BedrockAgentRuntimeClientTypes.KnowledgeBaseRetrievalResult

internal struct SerializableResult: Codable {
let content: String?
let score: Double?
let source: String?
}

/// A wrapper around RetrieveOutput providing convenient access to retrieval results
public struct RetrieveResult: Sendable {
/// The underlying AWS SDK RetrieveOutput
public let output: RetrieveOutput

/// Creates a new RetrieveResult from a RetrieveOutput
/// - Parameter output: The AWS SDK RetrieveOutput to wrap
public init(_ output: RetrieveOutput) {
self.output = output
}

/// The retrieval results from the knowledge base query
public var results: [RAGRetrievalResult]? {
output.retrievalResults
}

/// Returns the retrieval result with the highest relevance score
/// - Returns: The best matching result, or nil if no results
public func bestMatch() -> RAGRetrievalResult? {
output.retrievalResults?.max { ($0.score ?? 0) < ($1.score ?? 0) }
}

/// Converts the retrieval results to JSON format for use with language models
/// - Returns: JSON string representation of the results
/// - Throws: Error if JSON encoding fails
public func toJSON() throws -> String {
guard let results = output.retrievalResults else { return "[]" }

let serializableResults = results.map { result in
SerializableResult(
content: result.content?.text,
score: result.score,
source: result.location?.s3Location?.uri
)
}

let jsonData = try JSONEncoder().encode(serializableResults)
return String(data: jsonData, encoding: .utf8) ?? "[]"
}
}
Loading