Skip to content

Commit

Permalink
feat: add global secrets to StabilityAI connector (instill-ai#122)
Browse files Browse the repository at this point in the history
Because

- We want to support global secrets on StabilityAI

This commit

- Adds global secrets and usage handler on StabilityAI
  • Loading branch information
jvallesm committed May 13, 2024
1 parent 2d61f24 commit 1db0c9f
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 46 deletions.
17 changes: 17 additions & 0 deletions pkg/base/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package base
import (
"encoding/json"
"fmt"
"strings"

"github.com/gofrs/uuid"
"go.uber.org/zap"
Expand Down Expand Up @@ -304,3 +305,19 @@ func (e *ConnectorExecution) UsesSecret() bool {
func (e *ConnectorExecution) UsageHandlerCreator() UsageHandlerCreator {
return e.Connector.UsageHandlerCreator()
}

// ReadFromSecrets reads a component secret from a secret map that comes from
// environment variable configuration.
//
// Connection parameters are defined with snake_case, but the
// environment variable configuration loader replaces underscores by dots,
// so we can't use the parameter key directly.
// TODO using camelCase in configuration fields would fix this issue.
func ReadFromSecrets(key string, secrets map[string]any) string {
sanitized := strings.ReplaceAll(key, "_", "")
if v, ok := secrets[sanitized].(string); ok {
return v
}

return ""
}
11 changes: 10 additions & 1 deletion pkg/connector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,16 @@ func Init(
connectorIDMap: map[string]*connector{},
}

conStore.Import(stabilityai.Init(baseConn))
{
// StabilityAI
conn := stabilityai.Init(baseConn)

// Secret doesn't allow hyphens
conn = conn.WithSecrets(secrets["stabilityai"]).
WithUsageHandlerCreator(usageHandlerCreators[conn.GetID()])
conStore.Import(conn)
}

conStore.Import(instill.Init(baseConn))
conStore.Import(huggingface.Init(baseConn))

Expand Down
15 changes: 1 addition & 14 deletions pkg/connector/openai/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"sync"

"github.com/gabriel-vasile/mimetype"
Expand Down Expand Up @@ -64,22 +63,10 @@ func Init(bc base.Connector) *Connector {
return con
}

// The connection parameter is defined with snake_case, but the
// environment variable configuration loader replaces underscores by dots,
// so we can't use the parameter key directly.
func readFromSecrets(key string, s map[string]any) string {
sanitized := strings.ReplaceAll(key, "_", "")
if v, ok := s[sanitized].(string); ok {
return v
}

return ""
}

// WithSecrets loads secrets into the connector, which can be used to configure
// it with globaly defined parameters.
func (c *Connector) WithSecrets(s map[string]any) *Connector {
c.secretAPIKey = readFromSecrets(cfgAPIKey, s)
c.secretAPIKey = base.ReadFromSecrets(cfgAPIKey, s)

return c
}
Expand Down
18 changes: 18 additions & 0 deletions pkg/connector/stabilityai/v0/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,21 @@ type errBody struct {
func (e errBody) Message() string {
return e.Msg
}

// getBasePath returns Stability AI's API URL. This configuration param allows
// us to override the API the connector will point to. It isn't meant to be
// exposed to users. Rather, it can serve to test the logic against a fake
// server.
// TODO instead of having the API value hardcoded in the codebase, it should be
// read from a config file or environment variable.
func getBasePath(config *structpb.Struct) string {
v, ok := config.GetFields()["base_path"]
if !ok {
return host
}
return v.GetStringValue()
}

func getAPIKey(config *structpb.Struct) string {
return config.GetFields()[cfgAPIKey].GetStringValue()
}
4 changes: 2 additions & 2 deletions pkg/connector/stabilityai/v0/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestConnector_ExecuteImageFromText(t *testing.T) {
})
c.Assert(err, qt.IsNil)

exec, err := connector.CreateExecution(nil, connection, textToImageTask)
exec, err := connector.CreateExecution(nil, connection, TextToImageTask)
c.Assert(err, qt.IsNil)

weights := []float64{weight}
Expand Down Expand Up @@ -192,7 +192,7 @@ func TestConnector_ExecuteImageFromImage(t *testing.T) {
})
c.Assert(err, qt.IsNil)

exec, err := connector.CreateExecution(nil, connection, imageToImageTask)
exec, err := connector.CreateExecution(nil, connection, ImageToImageTask)
c.Assert(err, qt.IsNil)

weights := []float64{weight}
Expand Down
104 changes: 75 additions & 29 deletions pkg/connector/stabilityai/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ import (
)

const (
host = "https://api.stability.ai"
textToImageTask = "TASK_TEXT_TO_IMAGE"
imageToImageTask = "TASK_IMAGE_TO_IMAGE"
host = "https://api.stability.ai"

TextToImageTask = "TASK_TEXT_TO_IMAGE"
ImageToImageTask = "TASK_IMAGE_TO_IMAGE"

cfgAPIKey = "api_key"
)

var (
Expand All @@ -27,20 +30,21 @@ var (
//go:embed config/stabilityai.json
stabilityaiJSON []byte
once sync.Once
con *connector
con *Connector
)

type connector struct {
// Connector executes queries against StabilityAI.
type Connector struct {
base.Connector
}

type execution struct {
base.ConnectorExecution
usageHandlerCreator base.UsageHandlerCreator
secretAPIKey string
}

func Init(bc base.Connector) *connector {
// Init returns an initialized StabilityAI connector.
func Init(bc base.Connector) *Connector {
once.Do(func() {
con = &connector{Connector: bc}
con = &Connector{Connector: bc}
err := con.LoadConnectorDefinition(definitionJSON, tasksJSON, map[string][]byte{"stabilityai.json": stabilityaiJSON})
if err != nil {
panic(err)
Expand All @@ -50,28 +54,70 @@ func Init(bc base.Connector) *connector {
return con
}

func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) {
// WithSecrets loads secrets into the connector, which can be used to configure
// it with globaly defined parameters.
func (c *Connector) WithSecrets(s map[string]any) *Connector {
c.secretAPIKey = base.ReadFromSecrets(cfgAPIKey, s)

return c
}

// WithUsageHandlerCreator overrides the UsageHandlerCreator method.
func (c *Connector) WithUsageHandlerCreator(newUH base.UsageHandlerCreator) *Connector {
c.usageHandlerCreator = newUH
return c
}

// UsageHandlerCreator returns a function to initialize a UsageHandler.
func (c *Connector) UsageHandlerCreator() base.UsageHandlerCreator {
if c.usageHandlerCreator == nil {
return c.Connector.UsageHandlerCreator()
}
return c.usageHandlerCreator
}

// resolveSecrets looks for references to a global secret in the connection
// and replaces them by the global secret injected during initialization.
func (c *Connector) resolveSecrets(conn *structpb.Struct) (*structpb.Struct, bool, error) {
apiKey := conn.GetFields()[cfgAPIKey].GetStringValue()
if apiKey != base.SecretKeyword {
return conn, false, nil
}

if c.secretAPIKey == "" {
return nil, false, base.NewUnresolvedSecret(cfgAPIKey)
}

conn.GetFields()[cfgAPIKey] = structpb.NewStringValue(c.secretAPIKey)
return conn, true, nil
}

// CreateExecution initializes a connector executor that can be used in a
// pipeline trigger.
func (c *Connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) {
resolvedConnection, resolved, err := c.resolveSecrets(connection)
if err != nil {
return nil, err
}

return &base.ExecutionWrapper{Execution: &execution{
ConnectorExecution: base.ConnectorExecution{Connector: c, SystemVariables: sysVars, Connection: connection, Task: task},
ConnectorExecution: base.ConnectorExecution{
Connector: c,
SystemVariables: sysVars,
Connection: resolvedConnection,
Task: task,
},
usesSecret: resolved,
}}, nil
}

func getAPIKey(config *structpb.Struct) string {
return config.GetFields()["api_key"].GetStringValue()
type execution struct {
base.ConnectorExecution
usesSecret bool
}

// getBasePath returns Stability AI's API URL. This configuration param allows
// us to override the API the connector will point to. It isn't meant to be
// exposed to users. Rather, it can serve to test the logic against a fake
// server.
// TODO instead of having the API value hardcoded in the codebase, it should be
// read from a config file or environment variable.
func getBasePath(config *structpb.Struct) string {
v, ok := config.GetFields()["base_path"]
if !ok {
return host
}
return v.GetStringValue()
func (e *execution) UsesSecret() bool {
return e.usesSecret
}

func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
Expand All @@ -80,7 +126,7 @@ func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*st

for _, input := range inputs {
switch e.Task {
case textToImageTask:
case TextToImageTask:
params, err := parseTextToImageReq(input)
if err != nil {
return inputs, err
Expand All @@ -99,7 +145,7 @@ func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*st
}

outputs = append(outputs, output)
case imageToImageTask:
case ImageToImageTask:
params, err := parseImageToImageReq(input)
if err != nil {
return inputs, err
Expand Down Expand Up @@ -135,7 +181,7 @@ func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*st
}

// Test checks the connector state.
func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error {
func (c *Connector) Test(sysVars map[string]any, connection *structpb.Struct) error {
var engines []Engine
req := newClient(connection, c.Logger).R().SetResult(&engines)

Expand Down

0 comments on commit 1db0c9f

Please sign in to comment.