Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 87 additions & 86 deletions go/genkit/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,65 +35,100 @@ import (
sdktrace "go.opentelemetry.io/otel/sdk/trace"
)

// Plugin is a common interface for plugins.
type Plugin interface {
// Name returns the name of the plugin.
Name() string
// Init initializes the plugin.
Init(ctx context.Context, g *Genkit) error
}

// Genkit encapsulates a Genkit instance including the registry and configuration.
type Genkit struct {
// Registry for all actions contained in this instance.
reg *registry.Registry
// Params to configure calls using this instance.
Params *GenkitParams
reg *registry.Registry // Registry for all actions contained in this instance.
DefaultModel string // The default model to use if no model is specified.
PromptDir string // Directory where dotprompts are stored.
}

type genkitOption = func(params *GenkitParams) error
// genkitOptions are options for configuring the Genkit instance.
type genkitOptions struct {
DefaultModel string // The default model to use if no model is specified.
PromptDir string // Directory where dotprompts are stored.
Plugins []Plugin // Plugin to initialize automatically.
}

type GenkitParams struct {
DefaultModel string // The default model to use if no model is specified.
PromptDir string // Directory where dotprompts are stored.
type GenkitOption interface {
apply(g *genkitOptions) error
}

// WithDefaultModel sets the default model to use if no model is specified.
func WithDefaultModel(model string) genkitOption {
return func(params *GenkitParams) error {
if params.DefaultModel != "" {
return errors.New("genkit.WithDefaultModel: cannot set DefaultModel more than once")
// apply applies the options to the Genkit options.
func (o *genkitOptions) apply(gOpts *genkitOptions) error {
if o.DefaultModel != "" {
if gOpts.DefaultModel != "" {
return errors.New("cannot set default model more than once (WithDefaultModel)")
}
params.DefaultModel = model
return nil
gOpts.DefaultModel = o.DefaultModel
}
}

// WithPromptDir sets the directory where dotprompts are stored. Defaults to "prompts" at project root.
func WithPromptDir(dir string) genkitOption {
return func(params *GenkitParams) error {
if params.PromptDir != "" {
return errors.New("genkit.WithPromptDir: cannot set PromptDir more than once")
if o.PromptDir != "" {
if gOpts.PromptDir != "" {
return errors.New("cannot set prompt directory more than once (WithPromptDir)")
}
gOpts.PromptDir = o.PromptDir
}

if len(o.Plugins) > 0 {
if gOpts.Plugins != nil {
return errors.New("cannot set plugins more than once (WithPlugins)")
}
params.PromptDir = dir
return nil
gOpts.Plugins = o.Plugins
}

return nil
}

// WithPlugins sets the plugins to use.
func WithPlugins(plugins ...Plugin) GenkitOption {
return &genkitOptions{Plugins: plugins}
}

// WithDefaultModel sets the default model to use if no model is specified.
func WithDefaultModel(model string) GenkitOption {
return &genkitOptions{DefaultModel: model}
}

// WithPromptDir sets the directory where dotprompts are stored. Defaults to "prompts" at project root.
func WithPromptDir(dir string) GenkitOption {
return &genkitOptions{PromptDir: dir}
}

// Init creates a new Genkit instance.
//
// During local development (GENKIT_ENV=dev), it starts the Reflection API server (default :3100) as a side effect.
func Init(ctx context.Context, opts ...genkitOption) (*Genkit, error) {
func Init(ctx context.Context, opts ...GenkitOption) (*Genkit, error) {
ctx, _ = signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)

r, err := registry.New()
if err != nil {
return nil, err
}

params := &GenkitParams{}
gOpts := &genkitOptions{}
for _, opt := range opts {
if err := opt(params); err != nil {
return nil, err
if err := opt.apply(gOpts); err != nil {
return nil, fmt.Errorf("genkit.Init: error applying options: %w", err)
}
}

if params.DefaultModel != "" {
_, err := modelRefParts(params.DefaultModel)
if err != nil {
return nil, err
g := &Genkit{
reg: r,
DefaultModel: gOpts.DefaultModel,
PromptDir: gOpts.PromptDir,
}

for _, plugin := range gOpts.Plugins {
if err := plugin.Init(ctx, g); err != nil {
return nil, fmt.Errorf("genkit.Init: plugin %T initialization failed: %w", plugin, err)
}
}

Expand All @@ -112,30 +147,23 @@ func Init(ctx context.Context, opts ...genkitOption) (*Genkit, error) {

select {
case err := <-errCh:
return nil, fmt.Errorf("reflection server startup failed: %w", err)
return nil, fmt.Errorf("genkit.Init: reflection server startup failed: %w", err)
case <-serverStartCh:
slog.Debug("reflection server started successfully")
case <-ctx.Done():
return nil, ctx.Err()
}
}

return &Genkit{
reg: r,
Params: params,
}, nil
return g, nil
}

// DefineFlow creates a Flow that runs fn, and registers it as an action. fn takes an input of type In and returns an output of type Out.
func DefineFlow[In, Out any](
g *Genkit,
name string,
fn core.Func[In, Out],
) *core.Flow[In, Out, struct{}] {
// DefineFlow creates a [core.Flow] that runs fn, and registers it as a [core.Action]. fn takes an input of type In and returns an output of type Out.
func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *core.Flow[In, Out, struct{}] {
return core.DefineFlow(g.reg, name, fn)
}

// DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action.
// DefineStreamingFlow creates a streaming [core.Flow] that runs fn, and registers it as a [core.Action].
//
// fn takes an input of type In and returns an output of type Out, optionally
// streaming values of type Stream incrementally by invoking a callback.
Expand All @@ -144,11 +172,7 @@ func DefineFlow[In, Out any](
// stream the results by invoking the callback periodically, ultimately returning
// with a final return value that includes all the streamed data.
// Otherwise, it should ignore the callback and just return a result.
func DefineStreamingFlow[In, Out, Stream any](
g *Genkit,
name string,
fn core.StreamingFunc[In, Out, Stream],
) *core.Flow[In, Out, Stream] {
func DefineStreamingFlow[In, Out, Stream any](g *Genkit, name string, fn core.StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream] {
return core.DefineStreamingFlow(g.reg, name, fn)
}

Expand All @@ -175,28 +199,23 @@ func ListFlows(g *Genkit) []core.Action {
return flows
}

// DefineModel registers the given generate function as an action, and returns a [Model] that runs it.
func DefineModel(
g *Genkit,
provider, name string,
info *ai.ModelInfo,
generate ai.ModelFunc,
) ai.Model {
return ai.DefineModel(g.reg, provider, name, info, generate)
// DefineModel registers the given generate function as an action, and returns a [ai.Model] that runs it.
func DefineModel(g *Genkit, provider, name string, info *ai.ModelInfo, fn ai.ModelFunc) ai.Model {
return ai.DefineModel(g.reg, provider, name, info, fn)
}

// IsDefinedModel reports whether a model is defined.
func IsDefinedModel(g *Genkit, provider, name string) bool {
return ai.IsDefinedModel(g.reg, provider, name)
}

// LookupModel looks up a [Model] registered by [DefineModel].
// LookupModel looks up a [ai.Model] registered by [DefineModel].
// It returns nil if the model was not defined.
func LookupModel(g *Genkit, provider, name string) ai.Model {
return ai.LookupModel(g.reg, provider, name)
}

// DefineTool defines a tool to be passed to a model generate call.
// DefineTool defines a [ai.Tool] to be passed to a model generate call.
func DefineTool[In, Out any](g *Genkit, name, description string, fn func(ctx *ai.ToolContext, input In) (Out, error)) *ai.ToolDef[In, Out] {
return ai.DefineTool(g.reg, name, description, fn)
}
Expand Down Expand Up @@ -227,8 +246,8 @@ func LookupPrompt(g *Genkit, provider, name string) *ai.Prompt {
}

// GenerateWithRequest generates a model response using the given options, middleware, and streaming callback. This is to be used in conjunction with DefinePrompt and Prompt.Render().
func GenerateWithRequest(ctx context.Context, g *Genkit, req *ai.GenerateActionOptions, mw []ai.ModelMiddleware, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) {
return ai.GenerateWithRequest(ctx, g.reg, req, mw, cb)
func GenerateWithRequest(ctx context.Context, g *Genkit, actionOpts *ai.GenerateActionOptions, mw []ai.ModelMiddleware, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) {
return ai.GenerateWithRequest(ctx, g.reg, actionOpts, mw, cb)
}

// Generate generates a model response using the given options.
Expand All @@ -252,11 +271,6 @@ func DefineIndexer(g *Genkit, provider, name string, index func(context.Context,
return ai.DefineIndexer(g.reg, provider, name, index)
}

// IsDefinedIndexer reports whether an [Indexer] is defined.
func IsDefinedIndexer(g *Genkit, provider, name string) bool {
return ai.IsDefinedIndexer(g.reg, provider, name)
}

// LookupIndexer looks up an [Indexer] registered by [DefineIndexer].
// It returns nil if the model was not defined.
func LookupIndexer(g *Genkit, provider, name string) ai.Indexer {
Expand All @@ -269,11 +283,6 @@ func DefineRetriever(g *Genkit, provider, name string, ret func(context.Context,
return ai.DefineRetriever(g.reg, provider, name, ret)
}

// IsDefinedRetriever reports whether a [Retriever] is defined.
func IsDefinedRetriever(g *Genkit, provider, name string) bool {
return ai.IsDefinedRetriever(g.reg, provider, name)
}

// LookupRetriever looks up a [Retriever] registered by [DefineRetriever].
// It returns nil if the retriever was not defined.
func LookupRetriever(g *Genkit, provider, name string) ai.Retriever {
Expand All @@ -286,17 +295,18 @@ func DefineEmbedder(g *Genkit, provider, name string, embed func(context.Context
return ai.DefineEmbedder(g.reg, provider, name, embed)
}

// IsDefinedEmbedder reports whether an embedder is defined.
func IsDefinedEmbedder(g *Genkit, provider, name string) bool {
return ai.IsDefinedEmbedder(g.reg, provider, name)
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(g *Genkit, provider, name string) ai.Embedder {
return ai.LookupEmbedder(g.reg, provider, name)
}

// LookupPlugin looks up a plugin registered on initialization.
// It returns nil if the plugin was not registered.
func LookupPlugin(g *Genkit, name string) any {
return g.reg.LookupPlugin(name)
}

// DefineEvaluator registers the given evaluator function as an action, and
// returns a [Evaluator] that runs it. This method process the input dataset
// one-by-one.
Expand Down Expand Up @@ -338,17 +348,8 @@ func RegisterSpanProcessor(g *Genkit, sp sdktrace.SpanProcessor) {

// optsWithDefaults prepends defaults to the options so that they can be overridden by the caller.
func optsWithDefaults(g *Genkit, opts []ai.GenerateOption) []ai.GenerateOption {
if g.Params.DefaultModel != "" {
opts = append([]ai.GenerateOption{ai.WithModelName(g.Params.DefaultModel)}, opts...)
if g.DefaultModel != "" {
opts = append([]ai.GenerateOption{ai.WithModelName(g.DefaultModel)}, opts...)
}
return opts
}

// modelRefParts parses a model string into a provider and name.
func modelRefParts(model string) ([]string, error) {
parts := strings.Split(model, "/")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid model format %q, expected provider/name", model)
}
return parts, nil
}
4 changes: 2 additions & 2 deletions go/internal/doc-snippets/dotprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func dot01() error {
ctx = context.Background()

// Default to the project in GCLOUD_PROJECT and the location "us-central1".
vertexai.Init(ctx, g, nil)
(&vertexai.VertexAI{}).Init(ctx, g)

// The .prompt file specifies vertexai/gemini-2.0-flash, which is
// automatically defined by Init(). However, if it specified a model that
Expand Down Expand Up @@ -98,7 +98,7 @@ func dot02() {

// [START dot02]
// Make sure you set up the model you're using.
vertexai.DefineModel(g, "gemini-2.0-flash", nil)
vertexai.Model(g, "gemini-2.0-flash")

response, err := prompt.Execute(
context.Background(),
Expand Down
8 changes: 2 additions & 6 deletions go/internal/doc-snippets/gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,12 @@ func gcpEx(ctx context.Context) error {
log.Fatal(err)
}
// [START init]
if err := googlecloud.Init(
ctx,
g,
googlecloud.Config{ProjectID: "your-google-cloud-project"},
); err != nil {
if err := (&googlecloud.GoogleCloud{ProjectID: "your-google-cloud-project"}).Init(ctx, g); err != nil {
return err
}
// [END init]

_ = googlecloud.Config{
_ = googlecloud.GoogleCloud{
ProjectID: "your-google-cloud-project",
ForceExport: true,
MetricInterval: 45e9,
Expand Down
4 changes: 2 additions & 2 deletions go/internal/doc-snippets/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ func googleaiEx(ctx context.Context) error {
}

// [START init]
if err := googleai.Init(ctx, g, nil); err != nil {
if err := (&googleai.GoogleAI{}).Init(ctx, g); err != nil {
return err
}
// [END init]

yourKey := ""
// [START initkey]
if err := googleai.Init(ctx, g, &googleai.Config{APIKey: yourKey}); err != nil {
if err := (&googleai.GoogleAI{APIKey: yourKey}).Init(ctx, g); err != nil {
return err
}
// [END initkey]
Expand Down
10 changes: 3 additions & 7 deletions go/internal/doc-snippets/init/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,12 @@ import (
func main() {
ctx := context.Background()

g, err := genkit.Init(ctx)
if err != nil {
log.Fatal(err)
}

// Initialize the Google AI plugin. When you pass nil for the
// Initialize Genkit with the Google AI plugin. When you pass nil for the
// Config parameter, the Google AI plugin will get the API key from the
// GOOGLE_GENAI_API_KEY environment variable, which is the recommended
// practice.
if err := googleai.Init(ctx, g, nil); err != nil {
g, err := genkit.Init(ctx, genkit.WithPlugins(&googleai.GoogleAI{}))
if err != nil {
log.Fatal(err)
}

Expand Down
2 changes: 1 addition & 1 deletion go/internal/doc-snippets/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func m1() error {
// Default to the value of GCLOUD_PROJECT for the project,
// and "us-central1" for the location.
// To specify these values directly, pass a vertexai.Config value to Init.
if err := vertexai.Init(ctx, g, nil); err != nil {
if err := (&vertexai.VertexAI{}).Init(ctx, g); err != nil {
return err
}
// [END init]
Expand Down
7 changes: 3 additions & 4 deletions go/internal/doc-snippets/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ func ollamaEx(ctx context.Context) error {

// [START init]
// Init with Ollama's default local address.
if err := ollama.Init(ctx, &ollama.Config{
ServerAddress: "http://127.0.0.1:11434",
}); err != nil {
o := &ollama.Ollama{ServerAddress: "http://127.0.0.1:11434"}
if err := o.Init(ctx, g); err != nil {
return err
}
// [END init]

// [START definemodel]
name := "gemma2"
model := ollama.DefineModel(
model := o.DefineModel(
g,
ollama.ModelDefinition{
Name: name,
Expand Down
Loading
Loading