Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 103 additions & 18 deletions go/ai/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package ai

import (
"encoding/json"
"fmt"
)

// A Document is a piece of data that can be embedded, indexed, or retrieved.
Expand All @@ -30,21 +31,64 @@ type Document struct {
// A Part is one part of a [Document]. This may be plain text or it
// may be a URL (possibly a "data:" URL with embedded data).
type Part struct {
isText bool
contentType string
text string
kind partKind
contentType string // valid for kind==blob
text string // valid for kind∈{text,blob}
toolRequest *ToolRequest // valid for kind==partToolRequest
toolResponse *ToolResponse // valid for kind==partToolResponse
}

type partKind int8

const (
partText partKind = iota
partBlob
partToolRequest
partToolResponse
)

// NewTextPart returns a Part containing raw string data.
func NewTextPart(text string) *Part {
return &Part{isText: true, text: text}
return &Part{kind: partText, text: text}
}

// NewBlobPart returns a Part containing structured data described
// by the given mimeType.
func NewBlobPart(mimeType, contents string) *Part {
return &Part{isText: false, contentType: mimeType, text: contents}
return &Part{kind: partBlob, contentType: mimeType, text: contents}
}

// NewToolRequestPart returns a Part containing a request from
// the model to the client to run a Tool.
// (Only genkit plugins should need to use this function.)
func NewToolRequestPart(r *ToolRequest) *Part {
return &Part{kind: partToolRequest, toolRequest: r}
}

// NewToolResponsePart returns a Part containing the results
// of applying a Tool that the model requested.
func NewToolResponsePart(r *ToolResponse) *Part {
return &Part{kind: partToolResponse, toolResponse: r}
}

// IsText reports whether the [Part] contains plain text.
func (p *Part) IsPlainText() bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be IsText to match the others.

return p.isText
return p.kind == partText
}

// IsBlob reports whether the [Part] contains blob (non-plain-text) data.
func (p *Part) IsBlob() bool {
return p.kind == partBlob
}

// IsToolRequest reports whether the [Part] contains a request to run a tool.
func (p *Part) IsToolRequest() bool {
return p.kind == partToolRequest
}

// IsToolResponse reports whether the [Part] contains the result of running a tool.
func (p *Part) IsToolResponse() bool {
return p.kind == partToolResponse
}

// Text returns the text. This is either plain text or a URL.
Expand All @@ -53,32 +97,64 @@ func (p *Part) Text() string {
}

// ContentType returns the type of the content.
// This is only interesting if IsText is false.
// This is only interesting if IsBlob() is true.
func (p *Part) ContentType() string {
if p.isText {
if p.kind == partText {
return "text/plain"
}
return p.contentType
}

// ToolRequest returns a request from the model for the client to run a tool.
// Valid only if [IsToolRequest] is true.
func (p *Part) ToolRequest() *ToolRequest {
return p.toolRequest
}

// ToolResponse returns the tool response the client is sending to the model.
// Valid only if [IsToolResponse] is true.
func (p *Part) ToolResponse() *ToolResponse {
return p.toolResponse
}

// MarshalJSON is called by the JSON marshaler to write out a Part.
func (p *Part) MarshalJSON() ([]byte, error) {
// This is not handled by the schema generator because
// Part is defined in TypeScript as a union.

if p.isText {
switch p.kind {
case partText:
v := textPart{
Text: p.text,
}
return json.Marshal(v)
} else {
case partBlob:
v := mediaPart{
Media: &mediaPartMedia{
ContentType: p.contentType,
Url: p.text,
},
}
return json.Marshal(v)
case partToolRequest:
// TODO: make sure these types marshal/unmarshal nicely
// between Go and javascript. At the very least the
// field name needs to change (here and in UnmarshalJSON).
v := struct {
ToolReq *ToolRequest `json:"toolreq,omitempty"`
}{
ToolReq: p.toolRequest,
}
return json.Marshal(v)
case partToolResponse:
v := struct {
ToolResp *ToolResponse `json:"toolresp,omitempty"`
}{
ToolResp: p.toolResponse,
}
return json.Marshal(v)
default:
return nil, fmt.Errorf("invalid part kind %v", p.kind)
}
}

Expand All @@ -88,20 +164,29 @@ func (p *Part) UnmarshalJSON(b []byte) error {
// Part is defined in TypeScript as a union.

var s struct {
Text string `json:"text,omitempty"`
Media *mediaPartMedia `json:"media,omitempty"`
Text string `json:"text,omitempty"`
Media *mediaPartMedia `json:"media,omitempty"`
ToolReq *ToolRequest `json:"toolreq,omitempty"`
ToolResp *ToolResponse `json:"toolresp,omitempty"`
}

if err := json.Unmarshal(b, &s); err != nil {
return err
}

if s.Media != nil {
p.isText = false
switch {
case s.Media != nil:
p.kind = partBlob
p.text = s.Media.Url
p.contentType = s.Media.ContentType
} else {
p.isText = true
case s.ToolReq != nil:
p.kind = partToolRequest
p.toolRequest = s.ToolReq
case s.ToolResp != nil:
p.kind = partToolResponse
p.toolResponse = s.ToolResp
default:
p.kind = partText
p.text = s.Text
p.contentType = ""
}
Expand All @@ -114,8 +199,8 @@ func DocumentFromText(text string, metadata map[string]any) *Document {
return &Document{
Content: []*Part{
&Part{
isText: true,
text: text,
kind: partText,
text: text,
},
},
Metadata: metadata,
Expand Down
36 changes: 30 additions & 6 deletions go/ai/document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package ai

import (
"encoding/json"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -41,13 +42,27 @@ func TestDocumentJSON(t *testing.T) {
d := Document{
Content: []*Part{
&Part{
isText: true,
kind: partText,
text: "hi",
},
&Part{
isText: false,
kind: partBlob,
contentType: "text/plain",
text: "data:,bye",
text: "data:,bye",
},
&Part{
kind: partToolRequest,
toolRequest: &ToolRequest{
Name: "tool1",
Input: map[string]any{"arg1": 3.3, "arg2": "foo"},
},
},
&Part{
kind: partToolResponse,
toolResponse: &ToolResponse{
Name: "tool1",
Output: map[string]any{"res1": 4.4, "res2": "bar"},
},
},
},
}
Expand All @@ -56,20 +71,29 @@ func TestDocumentJSON(t *testing.T) {
if err != nil {
t.Fatal(err)
}
t.Logf("marshaled:%s\n", string(b))

var d2 Document
if err := json.Unmarshal(b, &d2); err != nil {
t.Fatal(err)
}

cmpPart := func(a, b *Part) bool {
if a.isText != b.isText {
if a.kind != b.kind {
return false
}
if a.isText {
switch a.kind {
case partText:
return a.text == b.text
} else {
case partBlob:
return a.contentType == b.contentType && a.text == b.text
case partToolRequest:
return reflect.DeepEqual(a.toolRequest, b.toolRequest)
case partToolResponse:
return reflect.DeepEqual(a.toolResponse, b.toolResponse)
default:
t.Fatalf("bad part kind %v", a.kind)
return false
}
}

Expand Down
32 changes: 16 additions & 16 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,22 +135,22 @@ type ToolDefinition struct {
OutputSchema map[string]any `json:"outputSchema,omitempty"`
}

type ToolRequestPart struct {
ToolRequest *ToolRequestPartToolRequest `json:"toolRequest,omitempty"`
// A ToolRequest is a request from the model that the client should run
// a specific tool and pass a [ToolResponse] to the model on the next request it makes.
// Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client.
type ToolRequest struct {
// Input is a JSON object describing the input values to the tool.
// An example might be map[string]any{"country":"USA", "president":3}.
Input map[string]any `json:"input,omitempty"`
Name string `json:"name,omitempty"`
}

type ToolRequestPartToolRequest struct {
Input any `json:"input,omitempty"`
Name string `json:"name,omitempty"`
Ref string `json:"ref,omitempty"`
}

type ToolResponsePart struct {
ToolResponse *ToolResponsePartToolResponse `json:"toolResponse,omitempty"`
}

type ToolResponsePartToolResponse struct {
Name string `json:"name,omitempty"`
Output any `json:"output,omitempty"`
Ref string `json:"ref,omitempty"`
// A ToolResponse is a response from the client to the model containing
// the results of running a specific tool on the arguments passed to the client
// by the model in a [ToolRequest].
type ToolResponse struct {
Name string `json:"name,omitempty"`
// Output is a JSON object describing the results of running the tool.
// An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}.
Output map[string]any `json:"output,omitempty"`
}
29 changes: 29 additions & 0 deletions go/genkit/schemas.config
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,35 @@ RoleTool indicates this message was generated by a local tool, likely triggered
from the model in one of its previous responses.
.

ToolRequestPart omit
ToolRequestPartToolRequest name ToolRequest
ToolResponsePart omit
ToolResponsePartToolResponse name ToolResponse

ToolRequestPartToolRequest.input type map[string]any
ToolRequestPartToolRequest.input doc
Input is a JSON object describing the input values to the tool.
An example might be map[string]any{"country":"USA", "president":3}.
.
ToolResponsePartToolResponse.output type map[string]any
ToolResponsePartToolResponse.output doc
Output is a JSON object describing the results of running the tool.
An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}.
.
ToolRequestPartToolRequest.ref omit
ToolResponsePartToolResponse.ref omit

ToolRequestPartToolRequest doc
A ToolRequest is a message from the model to the client that it should run a
specific tool and pass a [ToolResponse] to the model on the next chat request it makes.
Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client.
.
ToolResponsePartToolResponse doc
A ToolResponse is a message from the client to the model containing
the results of running a specific tool on the arguments passed to the client
by the model in a [ToolRequest].
.

Candidate pkg ai
CandidateFinishReason pkg ai
DocumentData pkg ai
Expand Down
3 changes: 3 additions & 0 deletions go/internal/cmd/jsonschemagen/jsonschemagen.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ func (g *generator) generateStruct(name string, s *Schema, tcfg *itemConfig) err
if fs.Not != nil {
continue
}
if fcfg.omit {
continue
}
typeExpr := fcfg.typeExpr
if typeExpr == "" {
var err error
Expand Down
Loading