From 5c590540129db30a9e13987b13d47a538090e4d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Stormacq?= Date: Sat, 7 Jun 2025 20:52:10 +0200 Subject: [PATCH 1/3] support cross region inference --- .../frontend/helpers/chatModelData.js | 4 +- .../frontend/helpers/modelData.js | 4 +- .../frontend/helpers/reasoningModelData.js | 4 +- README.md | 2 +- Sources/BedrockModel.swift | 15 ++++++- .../Converse/BedrockService+Converse.swift | 2 +- .../BedrockService+ConverseStreaming.swift | 2 +- Sources/Converse/ConverseRequest.swift | 4 +- .../Converse/ConverseRequestStreaming.swift | 4 +- .../BedrockService+InvokeModelImage.swift | 41 ++++++------------- .../BedrockService+InvokeModelText.swift | 2 +- Sources/InvokeModel/InvokeModelRequest.swift | 6 +-- Sources/ListModels/ModelSummary.swift | 2 +- Sources/Modalities/CrossRegionInference.swift | 24 +++++++++++ Sources/Models/Amazon/Nova/Nova.swift | 2 +- Sources/Models/Anthropic/Anthropic.swift | 2 +- .../Anthropic/AnthropicBedrockModels.swift | 14 +++---- Sources/Models/DeepSeek/DeepSeek.swift | 2 +- .../DeepSeek/DeepSeekBedrockModels.swift | 2 +- Sources/Models/Llama/Llama.swift | 2 +- Sources/Models/Llama/LlamaBedrockModels.swift | 10 ++--- .../Models/Mistral/MistralBedrockModels.swift | 2 + Sources/Region.swift | 15 +++++++ Tests/Mock/MockBedrockClient.swift | 2 - Tests/Mock/MockBedrockRuntimeClient.swift | 19 +++++++-- 25 files changed, 116 insertions(+), 72 deletions(-) create mode 100644 Sources/Modalities/CrossRegionInference.swift diff --git a/Examples/web-playground/frontend/helpers/chatModelData.js b/Examples/web-playground/frontend/helpers/chatModelData.js index a8736cb7..acb66a20 100644 --- a/Examples/web-playground/frontend/helpers/chatModelData.js +++ b/Examples/web-playground/frontend/helpers/chatModelData.js @@ -125,7 +125,7 @@ export const chatModels = [ }, { modelName: "Anthropic Claude 3.5 Haiku", - modelId: "us.anthropic.claude-3-5-haiku-20241022-v1:0", + modelId: "anthropic.claude-3-5-haiku-20241022-v1:0", temperatureRange: { default: 1, min: 0, @@ -287,7 +287,7 @@ export const chatModels = [ // DeepSeek // { // modelName: "Deep Seek", - // modelId: "us.deepseek.r1-v1:0", + // modelId: "deepseek.r1-v1:0", // topPRange: { // max: 1, // default: 1, diff --git a/Examples/web-playground/frontend/helpers/modelData.js b/Examples/web-playground/frontend/helpers/modelData.js index f1396018..e4717147 100644 --- a/Examples/web-playground/frontend/helpers/modelData.js +++ b/Examples/web-playground/frontend/helpers/modelData.js @@ -60,7 +60,7 @@ export const models = [ }, { modelName: "Anthropic Claude 3.5 Haiku", - modelId: "us.anthropic.claude-3-5-haiku-20241022-v1:0", + modelId: "anthropic.claude-3-5-haiku-20241022-v1:0", temperatureRange: { min: 0, max: 1, @@ -173,7 +173,7 @@ export const models = [ }, // { // modelName: "Deep Seek", - // modelId: "us.deepseek.r1-v1:0", + // modelId: "deepseek.r1-v1:0", // temperatureRange: { // min: 0, // max: 1, diff --git a/Examples/web-playground/frontend/helpers/reasoningModelData.js b/Examples/web-playground/frontend/helpers/reasoningModelData.js index f2f03ccd..ef3a11f7 100644 --- a/Examples/web-playground/frontend/helpers/reasoningModelData.js +++ b/Examples/web-playground/frontend/helpers/reasoningModelData.js @@ -1,6 +1,6 @@ export const defaultModel = { modelName: "Claude V3.7 Sonnet", - modelId: "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + modelId: "anthropic.claude-3-7-sonnet-20250219-v1:0", topKRange: { max: 500, default: 0, @@ -32,7 +32,7 @@ export const models = [ defaultModel, // { // modelName: "Deep Seek", - // modelId: "us.deepseek.r1-v1:0", + // modelId: "deepseek.r1-v1:0", // topPRange: { // max: 1, // default: 1, diff --git a/README.md b/README.md index fd90dad4..bbe25503 100644 --- a/README.md +++ b/README.md @@ -925,7 +925,7 @@ You can now create instances for any of the models that follow the request and r ```swift extension BedrockModel { public static let llama3_3_70b_instruct: BedrockModel = BedrockModel( - id: "us.meta.llama3-3-70b-instruct-v1:0", + id: "meta.llama3-3-70b-instruct-v1:0", name: "Llama 3.3 70B Instruct", modality: LlamaText( parameters: TextGenerationParameters( diff --git a/Sources/BedrockModel.swift b/Sources/BedrockModel.swift index b0474dd1..e1e7f314 100644 --- a/Sources/BedrockModel.swift +++ b/Sources/BedrockModel.swift @@ -18,8 +18,8 @@ import Foundation public struct BedrockModel: Hashable, Sendable, Equatable, RawRepresentable { public var rawValue: String { id } - public var id: String - public var name: String + public let id: String + public let name: String public let modality: any Modality /// Creates a new BedrockModel instance @@ -106,6 +106,17 @@ public struct BedrockModel: Hashable, Sendable, Equatable, RawRepresentable { } } + // MARK: Cross region inference + public func getModelIdWithCrossRegionInferencePrefix(region: Region) -> String { + // If the model does not support cross region inference, return the model ID as is + guard let crossRegionInferenceModality = modality as? CrossRegionInferenceModality else { + return id + } + // If the model supports cross region inference, return the model ID with the appropriate prefix + let prefix = crossRegionInferenceModality.crossRegionPrefix(forRegion: region) + return "\(prefix)\(id)" + } + // MARK: Modality checks // MARK - Text completion diff --git a/Sources/Converse/BedrockService+Converse.swift b/Sources/Converse/BedrockService+Converse.swift index 59862acd..e69f1459 100644 --- a/Sources/Converse/BedrockService+Converse.swift +++ b/Sources/Converse/BedrockService+Converse.swift @@ -84,7 +84,7 @@ extension BedrockService { ) logger.trace("Creating ConverseInput") - let input = try converseRequest.getConverseInput() + let input = try converseRequest.getConverseInput(forRegion: self.region) logger.trace( "Sending ConverseInput to BedrockRuntimeClient", diff --git a/Sources/Converse/BedrockService+ConverseStreaming.swift b/Sources/Converse/BedrockService+ConverseStreaming.swift index eb83c8b0..21a189b6 100644 --- a/Sources/Converse/BedrockService+ConverseStreaming.swift +++ b/Sources/Converse/BedrockService+ConverseStreaming.swift @@ -92,7 +92,7 @@ extension BedrockService { ) logger.trace("Creating ConverseStreamingInput") - let input = try converseRequest.getConverseStreamingInput() + let input = try converseRequest.getConverseStreamingInput(forRegion: region) logger.trace( "Sending ConverseStreaminInput to BedrockRuntimeClient", diff --git a/Sources/Converse/ConverseRequest.swift b/Sources/Converse/ConverseRequest.swift index 1273a888..1ae0864d 100644 --- a/Sources/Converse/ConverseRequest.swift +++ b/Sources/Converse/ConverseRequest.swift @@ -53,12 +53,12 @@ public struct ConverseRequest { } } - func getConverseInput() throws -> ConverseInput { + func getConverseInput(forRegion region: Region) throws -> ConverseInput { ConverseInput( additionalModelRequestFields: try getAdditionalModelRequestFields(), inferenceConfig: inferenceConfig?.getSDKInferenceConfig(), messages: try getSDKMessages(), - modelId: model.id, + modelId: model.getModelIdWithCrossRegionInferencePrefix(region: region), system: getSDKSystemPrompts(), toolConfig: try toolConfig?.getSDKToolConfig() ) diff --git a/Sources/Converse/ConverseRequestStreaming.swift b/Sources/Converse/ConverseRequestStreaming.swift index 55f81a6d..508ab847 100644 --- a/Sources/Converse/ConverseRequestStreaming.swift +++ b/Sources/Converse/ConverseRequestStreaming.swift @@ -17,12 +17,12 @@ public typealias ConverseStreamingRequest = ConverseRequest extension ConverseStreamingRequest { - func getConverseStreamingInput() throws -> ConverseStreamInput { + func getConverseStreamingInput(forRegion region: Region) throws -> ConverseStreamInput { ConverseStreamInput( additionalModelRequestFields: try getAdditionalModelRequestFields(), inferenceConfig: inferenceConfig?.getSDKInferenceConfig(), messages: try getSDKMessages(), - modelId: model.id, + modelId: model.getModelIdWithCrossRegionInferencePrefix(region: region), system: getSDKSystemPrompts(), toolConfig: try toolConfig?.getSDKToolConfig() ) diff --git a/Sources/InvokeModel/BedrockService+InvokeModelImage.swift b/Sources/InvokeModel/BedrockService+InvokeModelImage.swift index 6758a779..9150e960 100644 --- a/Sources/InvokeModel/BedrockService+InvokeModelImage.swift +++ b/Sources/InvokeModel/BedrockService+InvokeModelImage.swift @@ -77,31 +77,9 @@ extension BedrockService { quality: quality, resolution: resolution ) - let input: InvokeModelInput = try request.getInvokeModelInput() - logger.trace( - "Sending request to invokeModel", - metadata: [ - "model": .string(model.id), "request": .string(String(describing: input)), - ] - ) - let response = try await self.bedrockRuntimeClient.invokeModel(input: input) - guard let responseBody = response.body else { - logger.trace( - "Invalid response", - metadata: [ - "response": .string(String(describing: response)), - "hasBody": .stringConvertible(response.body != nil), - ] - ) - throw BedrockLibraryError.invalidSDKResponse( - "Something went wrong while extracting body from response." - ) - } - let invokemodelResponse: InvokeModelResponse = try InvokeModelResponse.createImageResponse( - body: responseBody, - model: model - ) - return try invokemodelResponse.getGeneratedImage() + + return try await sendRequest(request: request, model: model) + } catch { try handleCommonError(error, context: "listing foundation models") } @@ -174,7 +152,15 @@ extension BedrockService { quality: quality, resolution: resolution ) - let input: InvokeModelInput = try request.getInvokeModelInput() + return try await sendRequest(request: request, model: model) + } catch { + try handleCommonError(error, context: "listing foundation models") + } + } + + /// Sends the request to invoke the model and returns the generated image(s) + private func sendRequest(request: InvokeModelRequest, model: BedrockModel) async throws -> ImageGenerationOutput { + let input: InvokeModelInput = try request.getInvokeModelInput(forRegion: self.region) logger.trace( "Sending request to invokeModel", metadata: [ @@ -199,9 +185,6 @@ extension BedrockService { model: model ) return try invokemodelResponse.getGeneratedImage() - } catch { - try handleCommonError(error, context: "listing foundation models") - } } /// Generates 1 to 5 image variation(s) from reference images and a text prompt using a specific model diff --git a/Sources/InvokeModel/BedrockService+InvokeModelText.swift b/Sources/InvokeModel/BedrockService+InvokeModelText.swift index 786a37b2..929c35f7 100644 --- a/Sources/InvokeModel/BedrockService+InvokeModelText.swift +++ b/Sources/InvokeModel/BedrockService+InvokeModelText.swift @@ -85,7 +85,7 @@ extension BedrockService { topK: topK, stopSequences: stopSequences ) - let input: InvokeModelInput = try request.getInvokeModelInput() + let input: InvokeModelInput = try request.getInvokeModelInput(forRegion: self.region) logger.trace( "Sending request to invokeModel", metadata: [ diff --git a/Sources/InvokeModel/InvokeModelRequest.swift b/Sources/InvokeModel/InvokeModelRequest.swift index 06e19512..86868152 100644 --- a/Sources/InvokeModel/InvokeModelRequest.swift +++ b/Sources/InvokeModel/InvokeModelRequest.swift @@ -205,14 +205,14 @@ struct InvokeModelRequest { /// Creates an InvokeModelInput instance for making a request to Amazon Bedrock /// - Returns: A configured InvokeModelInput containing the model ID, content type, and encoded request body /// - Throws: BedrockLibraryError.encodingError if the request body cannot be encoded to JSON - public func getInvokeModelInput() throws -> InvokeModelInput { - do { + public func getInvokeModelInput(forRegion region: Region) throws -> InvokeModelInput { + do { let jsonData: Data = try JSONEncoder().encode(self.body) return InvokeModelInput( accept: self.accept.headerValue, body: jsonData, contentType: self.contentType.headerValue, - modelId: model.id + modelId: model.getModelIdWithCrossRegionInferencePrefix(region: region) ) } catch { throw BedrockLibraryError.encodingError( diff --git a/Sources/ListModels/ModelSummary.swift b/Sources/ListModels/ModelSummary.swift index cc6478cb..95af7cab 100644 --- a/Sources/ListModels/ModelSummary.swift +++ b/Sources/ListModels/ModelSummary.swift @@ -59,7 +59,7 @@ public struct ModelSummary: Encodable { if sdkModelSummary.responseStreamingSupported != nil { responseStreamingSupported = sdkModelSummary.responseStreamingSupported! } - let bedrockModel = BedrockModel(rawValue: modelId) ?? BedrockModel(rawValue: "us.\(modelId)") + let bedrockModel = BedrockModel(rawValue: modelId) return ModelSummary( modelName: modelName, diff --git a/Sources/Modalities/CrossRegionInference.swift b/Sources/Modalities/CrossRegionInference.swift new file mode 100644 index 00000000..b869839d --- /dev/null +++ b/Sources/Modalities/CrossRegionInference.swift @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Bedrock Library open source project +// +// Copyright (c) 2025 Amazon.com, Inc. or its affiliates +// and the Swift Bedrock Library project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +public protocol CrossRegionInferenceModality: Sendable {} +extension CrossRegionInferenceModality { + public func crossRegionPrefix(forRegion region: Region) -> String { + if region.isEURegion() { return "eu." } + if region.isUSRegion() { return "us." } + if region.isAPRegion() { return "ap." } + return "" + } +} diff --git a/Sources/Models/Amazon/Nova/Nova.swift b/Sources/Models/Amazon/Nova/Nova.swift index 26625adf..0af45a3e 100644 --- a/Sources/Models/Amazon/Nova/Nova.swift +++ b/Sources/Models/Amazon/Nova/Nova.swift @@ -15,7 +15,7 @@ import Foundation -struct NovaText: TextModality, ConverseModality, ConverseStreamingModality { +struct NovaText: TextModality, ConverseModality, ConverseStreamingModality, CrossRegionInferenceModality { func getName() -> String { "Nova Text Generation" } let parameters: TextGenerationParameters diff --git a/Sources/Models/Anthropic/Anthropic.swift b/Sources/Models/Anthropic/Anthropic.swift index 5587bda9..548b498c 100644 --- a/Sources/Models/Anthropic/Anthropic.swift +++ b/Sources/Models/Anthropic/Anthropic.swift @@ -15,7 +15,7 @@ import Foundation -struct AnthropicText: TextModality, ConverseModality, ConverseStreamingModality { +struct AnthropicText: TextModality, ConverseModality, ConverseStreamingModality, CrossRegionInferenceModality { let parameters: TextGenerationParameters let converseParameters: ConverseParameters let converseFeatures: [ConverseFeature] diff --git a/Sources/Models/Anthropic/AnthropicBedrockModels.swift b/Sources/Models/Anthropic/AnthropicBedrockModels.swift index 497c231d..8cf1a8bc 100644 --- a/Sources/Models/Anthropic/AnthropicBedrockModels.swift +++ b/Sources/Models/Anthropic/AnthropicBedrockModels.swift @@ -93,7 +93,7 @@ extension BedrockModel { ) ) public static let claudev3_opus: BedrockModel = BedrockModel( - id: "us.anthropic.claude-3-opus-20240229-v1:0", + id: "anthropic.claude-3-opus-20240229-v1:0", name: "Claude V3 Opus", modality: ClaudeV3Opus( parameters: TextGenerationParameters( @@ -123,7 +123,7 @@ extension BedrockModel { ) ) public static let claudev3_5_haiku: BedrockModel = BedrockModel( - id: "us.anthropic.claude-3-5-haiku-20241022-v1:0", + id: "anthropic.claude-3-5-haiku-20241022-v1:0", name: "Claude V3.5 Haiku", modality: ClaudeV3_5Haiku( parameters: TextGenerationParameters( @@ -138,7 +138,7 @@ extension BedrockModel { ) ) public static let claudev3_5_sonnet: BedrockModel = BedrockModel( - id: "us.anthropic.claude-3-5-sonnet-20240620-v1:0", + id: "anthropic.claude-3-5-sonnet-20240620-v1:0", name: "Claude V3.5 Sonnet", modality: ClaudeV3_5Sonnet( parameters: TextGenerationParameters( @@ -153,7 +153,7 @@ extension BedrockModel { ) ) public static let claudev3_5_sonnet_v2: BedrockModel = BedrockModel( - id: "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + id: "anthropic.claude-3-5-sonnet-20241022-v2:0", name: "Claude V3.5 Sonnet V2", modality: ClaudeV3_5Sonnet( parameters: TextGenerationParameters( @@ -168,7 +168,7 @@ extension BedrockModel { ) ) public static let claudev3_7_sonnet: BedrockModel = BedrockModel( - id: "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + id: "anthropic.claude-3-7-sonnet-20250219-v1:0", name: "Claude V3.7 Sonnet", modality: ClaudeV3_7Sonnet( parameters: TextGenerationParameters( @@ -184,7 +184,7 @@ extension BedrockModel { ) ) public static let claude_sonnet_v4: BedrockModel = BedrockModel( - id: "us.anthropic.claude-sonnet-4-20250514-v1:0", + id: "anthropic.claude-sonnet-4-20250514-v1:0", name: "Claude Sonnet v4", modality: Claude_Sonnet_v4( parameters: TextGenerationParameters( @@ -200,7 +200,7 @@ extension BedrockModel { ) ) public static let claude_opus_v4: BedrockModel = BedrockModel( - id: "us.anthropic.claude-opus-4-20250514-v1:0", + id: "anthropic.claude-opus-4-20250514-v1:0", name: "Claude Opus v4", modality: Claude_Opus_v4( parameters: TextGenerationParameters( diff --git a/Sources/Models/DeepSeek/DeepSeek.swift b/Sources/Models/DeepSeek/DeepSeek.swift index f467db02..4531f351 100644 --- a/Sources/Models/DeepSeek/DeepSeek.swift +++ b/Sources/Models/DeepSeek/DeepSeek.swift @@ -30,7 +30,7 @@ import Foundation // filtered to remove the reasoning content blocks before it is sent to the model. // The same goes for ConverseStreamingModality. -struct DeepSeekText: TextModality { +struct DeepSeekText: TextModality, CrossRegionInferenceModality { let parameters: TextGenerationParameters let converseFeatures: [ConverseFeature] let converseParameters: ConverseParameters diff --git a/Sources/Models/DeepSeek/DeepSeekBedrockModels.swift b/Sources/Models/DeepSeek/DeepSeekBedrockModels.swift index c39594de..895f312a 100644 --- a/Sources/Models/DeepSeek/DeepSeekBedrockModels.swift +++ b/Sources/Models/DeepSeek/DeepSeekBedrockModels.swift @@ -19,7 +19,7 @@ typealias DeepSeekR1V1 = DeepSeekText extension BedrockModel { public static let deepseek_r1_v1: BedrockModel = BedrockModel( - id: "us.deepseek.r1-v1:0", + id: "deepseek.r1-v1:0", name: "DeepSeek R1", modality: DeepSeekR1V1( parameters: TextGenerationParameters( diff --git a/Sources/Models/Llama/Llama.swift b/Sources/Models/Llama/Llama.swift index ac8671f5..9ef1b871 100644 --- a/Sources/Models/Llama/Llama.swift +++ b/Sources/Models/Llama/Llama.swift @@ -15,7 +15,7 @@ import Foundation -struct LlamaText: TextModality, ConverseModality, ConverseStreamingModality { +struct LlamaText: TextModality, ConverseModality, ConverseStreamingModality, CrossRegionInferenceModality { func getName() -> String { "Llama Text Generation" } let parameters: TextGenerationParameters diff --git a/Sources/Models/Llama/LlamaBedrockModels.swift b/Sources/Models/Llama/LlamaBedrockModels.swift index ed2f93af..ef26cd79 100644 --- a/Sources/Models/Llama/LlamaBedrockModels.swift +++ b/Sources/Models/Llama/LlamaBedrockModels.swift @@ -49,7 +49,7 @@ extension BedrockModel { ) ) public static let llama3_1_8b_instruct: BedrockModel = BedrockModel( - id: "us.meta.llama3-1-8b-instruct-v1:0", + id: "meta.llama3-1-8b-instruct-v1:0", name: "Llama 3.1 8B Instruct", modality: LlamaText( parameters: TextGenerationParameters( @@ -64,7 +64,7 @@ extension BedrockModel { ) ) public static let llama3_1_70b_instruct: BedrockModel = BedrockModel( - id: "us.meta.llama3-1-70b-instruct-v1:0", + id: "meta.llama3-1-70b-instruct-v1:0", name: "Llama 3.1 70B Instruct", modality: LlamaText( parameters: TextGenerationParameters( @@ -79,7 +79,7 @@ extension BedrockModel { ) ) public static let llama3_2_1b_instruct: BedrockModel = BedrockModel( - id: "us.meta.llama3-2-1b-instruct-v1:0", + id: "meta.llama3-2-1b-instruct-v1:0", name: "Llama 3.2 1B Instruct", modality: LlamaText( parameters: TextGenerationParameters( @@ -94,7 +94,7 @@ extension BedrockModel { ) ) public static let llama3_2_3b_instruct: BedrockModel = BedrockModel( - id: "us.meta.llama3-2-3b-instruct-v1:0", + id: "meta.llama3-2-3b-instruct-v1:0", name: "Llama 3.2 3B Instruct", modality: LlamaText( parameters: TextGenerationParameters( @@ -109,7 +109,7 @@ extension BedrockModel { ) ) public static let llama3_3_70b_instruct: BedrockModel = BedrockModel( - id: "us.meta.llama3-3-70b-instruct-v1:0", + id: "meta.llama3-3-70b-instruct-v1:0", name: "Llama 3.3 70B Instruct", modality: LlamaText( parameters: TextGenerationParameters( diff --git a/Sources/Models/Mistral/MistralBedrockModels.swift b/Sources/Models/Mistral/MistralBedrockModels.swift index a558255a..d298fd9d 100644 --- a/Sources/Models/Mistral/MistralBedrockModels.swift +++ b/Sources/Models/Mistral/MistralBedrockModels.swift @@ -21,6 +21,8 @@ import Foundation typealias MistralConverse = StandardConverse +// TODO: define a new struct to support Pixtral Large 25 + cross region inference + extension BedrockModel { public static let mistral_large_2402 = BedrockModel( id: "mistral.mistral-large-2402-v1:0", diff --git a/Sources/Region.swift b/Sources/Region.swift index 27d36536..90121204 100644 --- a/Sources/Region.swift +++ b/Sources/Region.swift @@ -303,3 +303,18 @@ extension Region { } } } + +// Support for Bedrock cross region inference +// https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html +// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html +extension Region { + public func isUSRegion() -> Bool { + return self.rawValue.hasPrefix("us-") + } + public func isEURegion() -> Bool { + return self.rawValue.hasPrefix("eu-") + } + public func isAPRegion() -> Bool { + return self.rawValue.hasPrefix("ap-") + } +} \ No newline at end of file diff --git a/Tests/Mock/MockBedrockClient.swift b/Tests/Mock/MockBedrockClient.swift index d2aa90af..090e327d 100644 --- a/Tests/Mock/MockBedrockClient.swift +++ b/Tests/Mock/MockBedrockClient.swift @@ -14,8 +14,6 @@ //===----------------------------------------------------------------------===// @preconcurrency import AWSBedrock -import AWSClientRuntime -import AWSSDKIdentity import BedrockService import Foundation diff --git a/Tests/Mock/MockBedrockRuntimeClient.swift b/Tests/Mock/MockBedrockRuntimeClient.swift index 4ff289a3..9eaaddbe 100644 --- a/Tests/Mock/MockBedrockRuntimeClient.swift +++ b/Tests/Mock/MockBedrockRuntimeClient.swift @@ -14,8 +14,9 @@ //===----------------------------------------------------------------------===// @preconcurrency import AWSBedrockRuntime -import AWSClientRuntime -import AWSSDKIdentity + +import Testing + import BedrockService import Foundation @@ -172,9 +173,19 @@ public struct MockBedrockRuntimeClient: BedrockRuntimeClientProtocol { message: "Malformed input request, please reformat your input and try again." ) } - let model: BedrockModel = BedrockModel(rawValue: modelId)! - switch model.modality.getName() { + // remove the cross region inference prefix if it exists + // when modelId starts with "us.", "eu.", "ap.", remove it" + let prefixPattern: String = "^(us|eu|ap)\\." + let modelIdWithoutPrefix = modelId.replacingOccurrences( + of: prefixPattern, + with: "", + options: .regularExpression + ) + let model: BedrockModel? = BedrockModel(rawValue: modelIdWithoutPrefix) + #expect(model != nil, "Model with id \(modelIdWithoutPrefix) not found") + + switch model?.modality.getName() { case "Amazon Image Generation": return InvokeModelOutput(body: try getImageGeneration(body: inputBody)) case "Nova Text Generation": From d8f12aaf0b1d6da4ca2aa69226203b5899bd291e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Stormacq?= Date: Sat, 7 Jun 2025 20:52:40 +0200 Subject: [PATCH 2/3] swift-format --- .../BedrockService+InvokeModelImage.swift | 40 +++++++++---------- Sources/InvokeModel/InvokeModelRequest.swift | 2 +- Sources/Modalities/CrossRegionInference.swift | 12 +++--- .../Models/Mistral/MistralBedrockModels.swift | 2 +- Sources/Region.swift | 8 ++-- Tests/Mock/MockBedrockRuntimeClient.swift | 4 +- 6 files changed, 33 insertions(+), 35 deletions(-) diff --git a/Sources/InvokeModel/BedrockService+InvokeModelImage.swift b/Sources/InvokeModel/BedrockService+InvokeModelImage.swift index 9150e960..6061926c 100644 --- a/Sources/InvokeModel/BedrockService+InvokeModelImage.swift +++ b/Sources/InvokeModel/BedrockService+InvokeModelImage.swift @@ -160,31 +160,31 @@ extension BedrockService { /// Sends the request to invoke the model and returns the generated image(s) private func sendRequest(request: InvokeModelRequest, model: BedrockModel) async throws -> ImageGenerationOutput { - let input: InvokeModelInput = try request.getInvokeModelInput(forRegion: self.region) + let input: InvokeModelInput = try request.getInvokeModelInput(forRegion: self.region) + logger.trace( + "Sending request to invokeModel", + metadata: [ + "model": .string(model.id), "request": .string(String(describing: input)), + ] + ) + let response = try await self.bedrockRuntimeClient.invokeModel(input: input) + guard let responseBody = response.body else { logger.trace( - "Sending request to invokeModel", + "Invalid response", metadata: [ - "model": .string(model.id), "request": .string(String(describing: input)), + "response": .string(String(describing: response)), + "hasBody": .stringConvertible(response.body != nil), ] ) - let response = try await self.bedrockRuntimeClient.invokeModel(input: input) - guard let responseBody = response.body else { - logger.trace( - "Invalid response", - metadata: [ - "response": .string(String(describing: response)), - "hasBody": .stringConvertible(response.body != nil), - ] - ) - throw BedrockLibraryError.invalidSDKResponse( - "Something went wrong while extracting body from response." - ) - } - let invokemodelResponse: InvokeModelResponse = try InvokeModelResponse.createImageResponse( - body: responseBody, - model: model + throw BedrockLibraryError.invalidSDKResponse( + "Something went wrong while extracting body from response." ) - return try invokemodelResponse.getGeneratedImage() + } + let invokemodelResponse: InvokeModelResponse = try InvokeModelResponse.createImageResponse( + body: responseBody, + model: model + ) + return try invokemodelResponse.getGeneratedImage() } /// Generates 1 to 5 image variation(s) from reference images and a text prompt using a specific model diff --git a/Sources/InvokeModel/InvokeModelRequest.swift b/Sources/InvokeModel/InvokeModelRequest.swift index 86868152..6898b996 100644 --- a/Sources/InvokeModel/InvokeModelRequest.swift +++ b/Sources/InvokeModel/InvokeModelRequest.swift @@ -206,7 +206,7 @@ struct InvokeModelRequest { /// - Returns: A configured InvokeModelInput containing the model ID, content type, and encoded request body /// - Throws: BedrockLibraryError.encodingError if the request body cannot be encoded to JSON public func getInvokeModelInput(forRegion region: Region) throws -> InvokeModelInput { - do { + do { let jsonData: Data = try JSONEncoder().encode(self.body) return InvokeModelInput( accept: self.accept.headerValue, diff --git a/Sources/Modalities/CrossRegionInference.swift b/Sources/Modalities/CrossRegionInference.swift index b869839d..399a98e4 100644 --- a/Sources/Modalities/CrossRegionInference.swift +++ b/Sources/Modalities/CrossRegionInference.swift @@ -15,10 +15,10 @@ public protocol CrossRegionInferenceModality: Sendable {} extension CrossRegionInferenceModality { - public func crossRegionPrefix(forRegion region: Region) -> String { - if region.isEURegion() { return "eu." } - if region.isUSRegion() { return "us." } - if region.isAPRegion() { return "ap." } - return "" - } + public func crossRegionPrefix(forRegion region: Region) -> String { + if region.isEURegion() { return "eu." } + if region.isUSRegion() { return "us." } + if region.isAPRegion() { return "ap." } + return "" + } } diff --git a/Sources/Models/Mistral/MistralBedrockModels.swift b/Sources/Models/Mistral/MistralBedrockModels.swift index d298fd9d..de5e6c3e 100644 --- a/Sources/Models/Mistral/MistralBedrockModels.swift +++ b/Sources/Models/Mistral/MistralBedrockModels.swift @@ -21,7 +21,7 @@ import Foundation typealias MistralConverse = StandardConverse -// TODO: define a new struct to support Pixtral Large 25 + cross region inference +// TODO: define a new struct to support Pixtral Large 25 + cross region inference extension BedrockModel { public static let mistral_large_2402 = BedrockModel( diff --git a/Sources/Region.swift b/Sources/Region.swift index 90121204..73796e6a 100644 --- a/Sources/Region.swift +++ b/Sources/Region.swift @@ -309,12 +309,12 @@ extension Region { // https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html extension Region { public func isUSRegion() -> Bool { - return self.rawValue.hasPrefix("us-") + self.rawValue.hasPrefix("us-") } public func isEURegion() -> Bool { - return self.rawValue.hasPrefix("eu-") + self.rawValue.hasPrefix("eu-") } public func isAPRegion() -> Bool { - return self.rawValue.hasPrefix("ap-") + self.rawValue.hasPrefix("ap-") } -} \ No newline at end of file +} diff --git a/Tests/Mock/MockBedrockRuntimeClient.swift b/Tests/Mock/MockBedrockRuntimeClient.swift index 9eaaddbe..844e1b87 100644 --- a/Tests/Mock/MockBedrockRuntimeClient.swift +++ b/Tests/Mock/MockBedrockRuntimeClient.swift @@ -14,11 +14,9 @@ //===----------------------------------------------------------------------===// @preconcurrency import AWSBedrockRuntime - -import Testing - import BedrockService import Foundation +import Testing public struct MockBedrockRuntimeClient: BedrockRuntimeClientProtocol { public init() {} From b8f56c6bc954c0e3260038528643d7a0981f101e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Stormacq?= Date: Sat, 7 Jun 2025 20:58:01 +0200 Subject: [PATCH 3/3] Update Sources/InvokeModel/BedrockService+InvokeModelImage.swift Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- Sources/InvokeModel/BedrockService+InvokeModelImage.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/InvokeModel/BedrockService+InvokeModelImage.swift b/Sources/InvokeModel/BedrockService+InvokeModelImage.swift index 6061926c..322b7b35 100644 --- a/Sources/InvokeModel/BedrockService+InvokeModelImage.swift +++ b/Sources/InvokeModel/BedrockService+InvokeModelImage.swift @@ -154,7 +154,7 @@ extension BedrockService { ) return try await sendRequest(request: request, model: model) } catch { - try handleCommonError(error, context: "listing foundation models") + try handleCommonError(error, context: "invoking image model") } }