-
Notifications
You must be signed in to change notification settings - Fork 568
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add amazonbedrock AI provider Signed-off-by: Su Wei <suwei007@gmail.com> * add amazonbedrock, change model list to const var Signed-off-by: Su Wei <suwei007@gmail.com> * update iai config and auth cmd, add providerRegion Signed-off-by: Wei Su <wsuam@amazon.com> * fix filename wrong Signed-off-by: Wei Su <wsuam@amazon.com> * chore: added some doc info Signed-off-by: Alex Jones <alexsimonjones@gmail.com> --------- Signed-off-by: Su Wei <suwei007@gmail.com> Signed-off-by: Wei Su <wsuam@amazon.com> Signed-off-by: Alex Jones <alexsimonjones@gmail.com> Co-authored-by: Wei Su <wsuam@amazon.com> Co-authored-by: Aris Boutselis <aris.boutselis@senseon.io> Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
- Loading branch information
1 parent
4af0ad0
commit f1a7801
Showing
6 changed files
with
270 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
package ai | ||
|
||
import ( | ||
"context" | ||
"encoding/base64" | ||
"encoding/json" | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/fatih/color" | ||
|
||
"github.com/k8sgpt-ai/k8sgpt/pkg/cache" | ||
"github.com/k8sgpt-ai/k8sgpt/pkg/util" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/session" | ||
"github.com/aws/aws-sdk-go/service/bedrockruntime" | ||
) | ||
|
||
// AmazonBedRockClient represents the client for interacting with the Amazon Bedrock service. | ||
type AmazonBedRockClient struct { | ||
client *bedrockruntime.BedrockRuntime | ||
language string | ||
model string | ||
temperature float32 | ||
} | ||
|
||
// InvokeModelResponseBody represents the response body structure from the model invocation. | ||
type InvokeModelResponseBody struct { | ||
Completion string `json:"completion"` | ||
Stop_reason string `json:"stop_reason"` | ||
} | ||
|
||
// Amazon BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt) | ||
// https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html#bedrock-regions | ||
const BEDROCK_DEFAULT_REGION = "us-east-1" // default use us-east-1 region | ||
|
||
const ( | ||
US_East_1 = "us-east-1" | ||
US_West_2 = "us-west-2" | ||
AP_Southeast_1 = "ap-southeast-1" | ||
AP_Northeast_1 = "ap-northeast-1" | ||
EU_Central_1 = "eu-central-1" | ||
) | ||
|
||
var BEDROCKER_SUPPORTED_REGION = []string{ | ||
US_East_1, | ||
US_West_2, | ||
AP_Southeast_1, | ||
AP_Northeast_1, | ||
EU_Central_1, | ||
} | ||
|
||
const ( | ||
ModelAnthropicClaudeV2 = "anthropic.claude-v2" | ||
ModelAnthropicClaudeV1 = "anthropic.claude-v1" | ||
ModelAnthropicClaudeInstantV1 = "anthropic.claude-instant-v1" | ||
) | ||
|
||
var BEDROCK_MODELS = []string{ | ||
ModelAnthropicClaudeV2, | ||
ModelAnthropicClaudeV1, | ||
ModelAnthropicClaudeInstantV1, | ||
} | ||
|
||
// GetModelOrDefault check config model | ||
func GetModelOrDefault(model string) string { | ||
|
||
// Check if the provided model is in the list | ||
for _, m := range BEDROCK_MODELS { | ||
if m == model { | ||
return model // Return the provided model | ||
} | ||
} | ||
|
||
// Return the default model if the provided model is not in the list | ||
return BEDROCK_MODELS[0] | ||
} | ||
|
||
// GetModelOrDefault check config region | ||
func GetRegionOrDefault(region string) string { | ||
|
||
// Check if the provided model is in the list | ||
for _, m := range BEDROCKER_SUPPORTED_REGION { | ||
if m == region { | ||
return region // Return the provided model | ||
} | ||
} | ||
|
||
// Return the default model if the provided model is not in the list | ||
return BEDROCK_DEFAULT_REGION | ||
} | ||
|
||
// Configure configures the AmazonBedRockClient with the provided configuration and language. | ||
func (a *AmazonBedRockClient) Configure(config IAIConfig, language string) error { | ||
|
||
// Create a new AWS session | ||
providerRegion := GetRegionOrDefault(config.GetProviderRegion()) | ||
|
||
sess, err := session.NewSession(&aws.Config{ | ||
Region: aws.String(providerRegion), | ||
}) | ||
|
||
if err != nil { | ||
return err | ||
} | ||
|
||
// Create a new BedrockRuntime client | ||
a.client = bedrockruntime.New(sess) | ||
a.language = language | ||
a.model = GetModelOrDefault(config.GetModel()) | ||
a.temperature = config.GetTemperature() | ||
|
||
return nil | ||
} | ||
|
||
// GetCompletion sends a request to the model for generating completion based on the provided prompt. | ||
func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string, promptTmpl string) (string, error) { | ||
|
||
// Prepare the input data for the model invocation | ||
request := map[string]interface{}{ | ||
"prompt": fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt), | ||
"max_tokens_to_sample": 1024, | ||
"temperature": a.temperature, | ||
"top_p": 0.9, | ||
} | ||
|
||
body, err := json.Marshal(request) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
// Build the parameters for the model invocation | ||
params := &bedrockruntime.InvokeModelInput{ | ||
Body: body, | ||
ModelId: aws.String(a.model), | ||
ContentType: aws.String("application/json"), | ||
Accept: aws.String("application/json"), | ||
} | ||
// Invoke the model | ||
resp, err := a.client.InvokeModelWithContext(ctx, params) | ||
|
||
if err != nil { | ||
return "", err | ||
} | ||
// Parse the response body | ||
output := &InvokeModelResponseBody{} | ||
err = json.Unmarshal(resp.Body, output) | ||
if err != nil { | ||
return "", err | ||
} | ||
return output.Completion, nil | ||
} | ||
|
||
// Parse generates a completion for the provided prompt using the Amazon Bedrock model. | ||
func (a *AmazonBedRockClient) Parse(ctx context.Context, prompt []string, cache cache.ICache, promptTmpl string) (string, error) { | ||
inputKey := strings.Join(prompt, " ") | ||
// Check for cached data | ||
cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey) | ||
|
||
if !cache.IsCacheDisabled() && cache.Exists(cacheKey) { | ||
response, err := cache.Load(cacheKey) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
if response != "" { | ||
output, err := base64.StdEncoding.DecodeString(response) | ||
if err != nil { | ||
color.Red("error decoding cached data: %v", err) | ||
return "", nil | ||
} | ||
return string(output), nil | ||
} | ||
} | ||
|
||
response, err := a.GetCompletion(ctx, inputKey, promptTmpl) | ||
|
||
if err != nil { | ||
return "", err | ||
} | ||
|
||
err = cache.Store(cacheKey, base64.StdEncoding.EncodeToString([]byte(response))) | ||
|
||
if err != nil { | ||
color.Red("error storing value to cache: %v", err) | ||
return "", nil | ||
} | ||
|
||
return response, nil | ||
} | ||
|
||
// GetName returns the name of the AmazonBedRockClient. | ||
func (a *AmazonBedRockClient) GetName() string { | ||
return "amazonbedrock" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters