Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add global secrets to StabilityAI connector #122

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions pkg/base/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package base
import (
"encoding/json"
"fmt"
"strings"

"github.com/gofrs/uuid"
"go.uber.org/zap"
Expand Down Expand Up @@ -304,3 +305,19 @@ func (e *ConnectorExecution) UsesSecret() bool {
func (e *ConnectorExecution) UsageHandlerCreator() UsageHandlerCreator {
return e.Connector.UsageHandlerCreator()
}

// ReadFromSecrets reads a component secret from a secret map that comes from
// environment variable configuration.
//
// Connection parameters are defined with snake_case, but the
// environment variable configuration loader replaces underscores by dots,
// so we can't use the parameter key directly.
// TODO using camelCase in configuration fields would fix this issue.
func ReadFromSecrets(key string, secrets map[string]any) string {
sanitized := strings.ReplaceAll(key, "_", "")
if v, ok := secrets[sanitized].(string); ok {
return v
}

return ""
}
11 changes: 10 additions & 1 deletion pkg/connector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,16 @@ func Init(
connectorIDMap: map[string]*connector{},
}

conStore.Import(stabilityai.Init(baseConn))
{
// StabilityAI
conn := stabilityai.Init(baseConn)

// Secret doesn't allow hyphens
conn = conn.WithSecrets(secrets["stabilityai"]).
WithUsageHandlerCreator(usageHandlerCreators[conn.GetID()])
conStore.Import(conn)
}

conStore.Import(instill.Init(baseConn))
conStore.Import(huggingface.Init(baseConn))

Expand Down
15 changes: 1 addition & 14 deletions pkg/connector/openai/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"sync"

"github.com/gabriel-vasile/mimetype"
Expand Down Expand Up @@ -64,22 +63,10 @@ func Init(bc base.Connector) *Connector {
return con
}

// The connection parameter is defined with snake_case, but the
// environment variable configuration loader replaces underscores by dots,
// so we can't use the parameter key directly.
func readFromSecrets(key string, s map[string]any) string {
sanitized := strings.ReplaceAll(key, "_", "")
if v, ok := s[sanitized].(string); ok {
return v
}

return ""
}

// WithSecrets loads secrets into the connector, which can be used to configure
// it with globaly defined parameters.
func (c *Connector) WithSecrets(s map[string]any) *Connector {
c.secretAPIKey = readFromSecrets(cfgAPIKey, s)
c.secretAPIKey = base.ReadFromSecrets(cfgAPIKey, s)

return c
}
Expand Down
18 changes: 18 additions & 0 deletions pkg/connector/stabilityai/v0/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,21 @@ type errBody struct {
func (e errBody) Message() string {
return e.Msg
}

// getBasePath returns Stability AI's API URL. This configuration param allows
// us to override the API the connector will point to. It isn't meant to be
// exposed to users. Rather, it can serve to test the logic against a fake
// server.
// TODO instead of having the API value hardcoded in the codebase, it should be
// read from a config file or environment variable.
func getBasePath(config *structpb.Struct) string {
v, ok := config.GetFields()["base_path"]
if !ok {
return host
}
return v.GetStringValue()
}

func getAPIKey(config *structpb.Struct) string {
return config.GetFields()[cfgAPIKey].GetStringValue()
}
4 changes: 2 additions & 2 deletions pkg/connector/stabilityai/v0/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestConnector_ExecuteImageFromText(t *testing.T) {
})
c.Assert(err, qt.IsNil)

exec, err := connector.CreateExecution(nil, connection, textToImageTask)
exec, err := connector.CreateExecution(nil, connection, TextToImageTask)
c.Assert(err, qt.IsNil)

weights := []float64{weight}
Expand Down Expand Up @@ -192,7 +192,7 @@ func TestConnector_ExecuteImageFromImage(t *testing.T) {
})
c.Assert(err, qt.IsNil)

exec, err := connector.CreateExecution(nil, connection, imageToImageTask)
exec, err := connector.CreateExecution(nil, connection, ImageToImageTask)
c.Assert(err, qt.IsNil)

weights := []float64{weight}
Expand Down
104 changes: 75 additions & 29 deletions pkg/connector/stabilityai/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ import (
)

const (
host = "https://api.stability.ai"
textToImageTask = "TASK_TEXT_TO_IMAGE"
imageToImageTask = "TASK_IMAGE_TO_IMAGE"
host = "https://api.stability.ai"

TextToImageTask = "TASK_TEXT_TO_IMAGE"
ImageToImageTask = "TASK_IMAGE_TO_IMAGE"

cfgAPIKey = "api_key"
Dismissed Show dismissed Hide dismissed
)

var (
Expand All @@ -27,20 +30,21 @@ var (
//go:embed config/stabilityai.json
stabilityaiJSON []byte
once sync.Once
con *connector
con *Connector
)

type connector struct {
// Connector executes queries against StabilityAI.
type Connector struct {
base.Connector
}

type execution struct {
base.ConnectorExecution
usageHandlerCreator base.UsageHandlerCreator
secretAPIKey string
}

func Init(bc base.Connector) *connector {
// Init returns an initialized StabilityAI connector.
func Init(bc base.Connector) *Connector {
once.Do(func() {
con = &connector{Connector: bc}
con = &Connector{Connector: bc}
err := con.LoadConnectorDefinition(definitionJSON, tasksJSON, map[string][]byte{"stabilityai.json": stabilityaiJSON})
if err != nil {
panic(err)
Expand All @@ -50,28 +54,70 @@ func Init(bc base.Connector) *connector {
return con
}

func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) {
// WithSecrets loads secrets into the connector, which can be used to configure
// it with globaly defined parameters.
func (c *Connector) WithSecrets(s map[string]any) *Connector {
c.secretAPIKey = base.ReadFromSecrets(cfgAPIKey, s)

return c
}

// WithUsageHandlerCreator overrides the UsageHandlerCreator method.
func (c *Connector) WithUsageHandlerCreator(newUH base.UsageHandlerCreator) *Connector {
c.usageHandlerCreator = newUH
return c
}

// UsageHandlerCreator returns a function to initialize a UsageHandler.
func (c *Connector) UsageHandlerCreator() base.UsageHandlerCreator {
if c.usageHandlerCreator == nil {
return c.Connector.UsageHandlerCreator()
}
return c.usageHandlerCreator
}

// resolveSecrets looks for references to a global secret in the connection
// and replaces them by the global secret injected during initialization.
func (c *Connector) resolveSecrets(conn *structpb.Struct) (*structpb.Struct, bool, error) {
apiKey := conn.GetFields()[cfgAPIKey].GetStringValue()
if apiKey != base.SecretKeyword {
return conn, false, nil
}

if c.secretAPIKey == "" {
return nil, false, base.NewUnresolvedSecret(cfgAPIKey)
}

conn.GetFields()[cfgAPIKey] = structpb.NewStringValue(c.secretAPIKey)
return conn, true, nil
}

// CreateExecution initializes a connector executor that can be used in a
// pipeline trigger.
func (c *Connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) {
resolvedConnection, resolved, err := c.resolveSecrets(connection)
if err != nil {
return nil, err
}

return &base.ExecutionWrapper{Execution: &execution{
ConnectorExecution: base.ConnectorExecution{Connector: c, SystemVariables: sysVars, Connection: connection, Task: task},
ConnectorExecution: base.ConnectorExecution{
Connector: c,
SystemVariables: sysVars,
Connection: resolvedConnection,
Task: task,
},
usesSecret: resolved,
}}, nil
}

func getAPIKey(config *structpb.Struct) string {
return config.GetFields()["api_key"].GetStringValue()
type execution struct {
base.ConnectorExecution
usesSecret bool
}

// getBasePath returns Stability AI's API URL. This configuration param allows
// us to override the API the connector will point to. It isn't meant to be
// exposed to users. Rather, it can serve to test the logic against a fake
// server.
// TODO instead of having the API value hardcoded in the codebase, it should be
// read from a config file or environment variable.
func getBasePath(config *structpb.Struct) string {
v, ok := config.GetFields()["base_path"]
if !ok {
return host
}
return v.GetStringValue()
func (e *execution) UsesSecret() bool {
return e.usesSecret
}

func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
Expand All @@ -80,7 +126,7 @@ func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*st

for _, input := range inputs {
switch e.Task {
case textToImageTask:
case TextToImageTask:
params, err := parseTextToImageReq(input)
if err != nil {
return inputs, err
Expand All @@ -99,7 +145,7 @@ func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*st
}

outputs = append(outputs, output)
case imageToImageTask:
case ImageToImageTask:
params, err := parseImageToImageReq(input)
if err != nil {
return inputs, err
Expand Down Expand Up @@ -135,7 +181,7 @@ func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*st
}

// Test checks the connector state.
func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error {
func (c *Connector) Test(sysVars map[string]any, connection *structpb.Struct) error {
var engines []Engine
req := newClient(connection, c.Logger).R().SetResult(&engines)

Expand Down
Loading