Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 54 additions & 75 deletions go/ai/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,146 +31,125 @@ type Document struct {
// A Part is one part of a [Document]. Depending on its kind it holds
// plain text, media (a content type plus a URL, possibly a "data:" URL
// with embedded data), raw string data, a tool request, or a tool response.
//
// NOTE(review): MarshalJSON builds its own wire structs for each kind, so
// the struct tags below do not appear to be consulted when a *Part is
// marshaled — confirm whether they are needed.
type Part struct {
kind partKind
contentType string // valid for kind==blob
text string // valid for kind∈{text,blob}
toolRequest *ToolRequest // valid for kind==partToolRequest
toolResponse *ToolResponse // valid for kind==partToolResponse
// Kind selects which of the remaining fields are meaningful.
Kind PartKind `json:"kind,omitempty"`
ContentType string `json:"contentType,omitempty"` // MIME type; set for PartMedia, and for PartText by NewTextPart/NewJSONPart
Text string `json:"text,omitempty"` // text for PartText, URL for PartMedia, raw string for PartData
ToolRequest *ToolRequest `json:"toolreq,omitempty"` // valid for Kind==PartToolRequest
ToolResponse *ToolResponse `json:"toolresp,omitempty"` // valid for Kind==PartToolResponse
}

type partKind int8
// PartKind identifies which variant of content a [Part] holds.
type PartKind int8

// The kinds of [Part]. Each kind has a matching constructor
// (NewTextPart, NewMediaPart, NewDataPart, NewToolRequestPart,
// NewToolResponsePart) and a matching Is* predicate on Part.
const (
partText partKind = iota
partMedia
partData
partToolRequest
partToolResponse
// PartText marks a part holding text.
PartText PartKind = iota
// PartMedia marks a part holding a content type and a URL.
PartMedia
// PartData marks a part holding a raw data string.
PartData
// PartToolRequest marks a part holding a *ToolRequest.
PartToolRequest
// PartToolResponse marks a part holding a *ToolResponse.
PartToolResponse
)

// NewTextPart returns a Part containing text.
// The part's content type is "text/plain", the registered media type
// for plain text (media types are written type/subtype).
func NewTextPart(text string) *Part {
return &Part{kind: partText, text: text}
return &Part{Kind: PartText, ContentType: "text/plain", Text: text}
}

// NewJSONPart builds a text-kind Part whose content type is
// "application/json" and whose Text is the given string.
// The string is stored as given; this function does not parse it.
func NewJSONPart(text string) *Part {
	part := new(Part)
	part.Kind = PartText
	part.ContentType = "application/json"
	part.Text = text
	return part
}

// NewMediaPart returns a media Part. mimeType is stored as the part's
// content type and contents as its text; MarshalJSON writes contents
// out as the media URL.
func NewMediaPart(mimeType, contents string) *Part {
return &Part{kind: partMedia, contentType: mimeType, text: contents}
return &Part{Kind: PartMedia, ContentType: mimeType, Text: contents}
}

// NewDataPart returns a data Part whose text is the raw string contents.
// No content type is set on the returned part.
func NewDataPart(contents string) *Part {
return &Part{kind: partData, text: contents}
return &Part{Kind: PartData, Text: contents}
}

// NewToolRequestPart returns a Part containing a request from
// the model to the client to run a Tool.
// The pointer r is stored in the part as given, not copied.
// (Only genkit plugins should need to use this function.)
func NewToolRequestPart(r *ToolRequest) *Part {
return &Part{kind: partToolRequest, toolRequest: r}
return &Part{Kind: PartToolRequest, ToolRequest: r}
}

// NewToolResponsePart returns a Part containing the results
// of applying a Tool that the model requested.
// The pointer r is stored in the part as given, not copied.
func NewToolResponsePart(r *ToolResponse) *Part {
return &Part{kind: partToolResponse, toolResponse: r}
return &Part{Kind: PartToolResponse, ToolResponse: r}
}

// IsText reports whether the [Part] is of the text kind.
// It compares only the part's kind and does not inspect the content.
func (p *Part) IsText() bool {
return p.kind == partText
return p.Kind == PartText
}

// IsMedia reports whether the [Part] is of the media kind, i.e. it
// carries a content type and a URL.
func (p *Part) IsMedia() bool {
return p.kind == partMedia
return p.Kind == PartMedia
}

// IsData reports whether the [Part] is of the data kind, i.e. it
// carries an unstructured data string.
func (p *Part) IsData() bool {
return p.kind == partData
return p.Kind == PartData
}

// IsToolRequest reports whether the [Part] is of the tool-request kind,
// i.e. it contains a request to run a tool.
func (p *Part) IsToolRequest() bool {
return p.kind == partToolRequest
return p.Kind == PartToolRequest
}

// IsToolResponse reports whether the [Part] is of the tool-response kind,
// i.e. it contains the result of running a tool.
func (p *Part) IsToolResponse() bool {
return p.kind == partToolResponse
}

// Text returns the part's text field. MarshalJSON writes this field as
// the text of a text part, the URL of a media part, or the data of a
// data part.
func (p *Part) Text() string {
return p.text
}

// ContentType returns the type of the content.
// For text parts it always reports "text/plain"; for every other kind
// it returns the stored content type.
// NOTE(review): this comment used to say the value is only interesting
// if IsBlob() is true, but no IsBlob is defined in the visible code;
// IsMedia looks like the intended predicate — confirm.
func (p *Part) ContentType() string {
if p.kind == partText {
return "text/plain"
}
return p.contentType
}

// ToolRequest returns a request from the model for the client to run a tool.
// Valid only if [IsToolRequest] is true.
func (p *Part) ToolRequest() *ToolRequest {
return p.toolRequest
}

// ToolResponse returns the tool response the client is sending to the model.
// Valid only if [IsToolResponse] is true.
func (p *Part) ToolResponse() *ToolResponse {
return p.toolResponse
return p.Kind == PartToolResponse
}

// MarshalJSON is called by the JSON marshaler to write out a Part.
// The JSON shape depends on the part's kind: a text part is written as a
// textPart, a media part as a mediaPart, a data part as a dataPart, and
// tool requests/responses as single-field objects keyed "toolreq" and
// "toolresp". An unrecognized kind yields an "invalid part kind" error.
func (p *Part) MarshalJSON() ([]byte, error) {
// This is not handled by the schema generator because
// Part is defined in TypeScript as a union.

switch p.kind {
case partText:
switch p.Kind {
case PartText:
// Only the text is emitted for a text part; the part's content
// type is not written.
v := textPart{
Text: p.text,
Text: p.Text,
}
return json.Marshal(v)
case partMedia:
case PartMedia:
// The media URL is carried in the part's text field.
v := mediaPart{
Media: &mediaPartMedia{
ContentType: p.contentType,
Url: p.text,
ContentType: p.ContentType,
Url: p.Text,
},
}
return json.Marshal(v)
case partData:
case PartData:
// The raw data string is carried in the part's text field.
v := dataPart{
Data: p.text,
Data: p.Text,
}
return json.Marshal(v)
case partToolRequest:
case PartToolRequest:
// TODO: make sure these types marshal/unmarshal nicely
// between Go and javascript. At the very least the
// field name needs to change (here and in UnmarshalJSON).
v := struct {
ToolReq *ToolRequest `json:"toolreq,omitempty"`
}{
ToolReq: p.toolRequest,
ToolReq: p.ToolRequest,
}
return json.Marshal(v)
case partToolResponse:
case PartToolResponse:
// Same wrapper shape as the tool request, keyed "toolresp".
v := struct {
ToolResp *ToolResponse `json:"toolresp,omitempty"`
}{
ToolResp: p.toolResponse,
ToolResp: p.ToolResponse,
}
return json.Marshal(v)
default:
return nil, fmt.Errorf("invalid part kind %v", p.kind)
return nil, fmt.Errorf("invalid part kind %v", p.Kind)
}
}

Expand All @@ -193,23 +172,23 @@ func (p *Part) UnmarshalJSON(b []byte) error {

switch {
case s.Media != nil:
p.kind = partMedia
p.text = s.Media.Url
p.contentType = s.Media.ContentType
p.Kind = PartMedia
p.Text = s.Media.Url
p.ContentType = s.Media.ContentType
case s.ToolReq != nil:
p.kind = partToolRequest
p.toolRequest = s.ToolReq
p.Kind = PartToolRequest
p.ToolRequest = s.ToolReq
case s.ToolResp != nil:
p.kind = partToolResponse
p.toolResponse = s.ToolResp
p.Kind = PartToolResponse
p.ToolResponse = s.ToolResp
default:
p.kind = partText
p.text = s.Text
p.contentType = ""
p.Kind = PartText
p.Text = s.Text
p.ContentType = ""
if s.Data != "" {
// Note: if part is completely empty, we use text by default.
p.kind = partData
p.text = s.Data
p.Kind = PartData
p.Text = s.Data
}
}
return nil
Expand All @@ -220,9 +199,9 @@ func (p *Part) UnmarshalJSON(b []byte) error {
func DocumentFromText(text string, metadata map[string]any) *Document {
return &Document{
Content: []*Part{
&Part{
kind: partText,
text: text,
{
Kind: PartText,
Text: text,
},
},
Metadata: metadata,
Expand Down
50 changes: 25 additions & 25 deletions go/ai/document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestDocumentFromText(t *testing.T) {
if !p.IsText() {
t.Errorf("IsText() == %t, want %t", p.IsText(), true)
}
if got := p.Text(); got != data {
if got := p.Text; got != data {
t.Errorf("Data() == %q, want %q", got, data)
}
}
Expand All @@ -42,28 +42,28 @@ func TestDocumentJSON(t *testing.T) {
d := Document{
Content: []*Part{
&Part{
kind: partText,
text: "hi",
Kind: PartText,
Text: "hi",
},
&Part{
kind: partMedia,
contentType: "text/plain",
text: "data:,bye",
Kind: PartMedia,
ContentType: "text/plain",
Text: "data:,bye",
},
&Part{
kind: partData,
text: "somedata\x00string",
Kind: PartData,
Text: "somedata\x00string",
},
&Part{
kind: partToolRequest,
toolRequest: &ToolRequest{
Kind: PartToolRequest,
ToolRequest: &ToolRequest{
Name: "tool1",
Input: map[string]any{"arg1": 3.3, "arg2": "foo"},
},
},
&Part{
kind: partToolResponse,
toolResponse: &ToolResponse{
Kind: PartToolResponse,
ToolResponse: &ToolResponse{
Name: "tool1",
Output: map[string]any{"res1": 4.4, "res2": "bar"},
},
Expand All @@ -83,22 +83,22 @@ func TestDocumentJSON(t *testing.T) {
}

cmpPart := func(a, b *Part) bool {
if a.kind != b.kind {
if a.Kind != b.Kind {
return false
}
switch a.kind {
case partText:
return a.text == b.text
case partMedia:
return a.contentType == b.contentType && a.text == b.text
case partData:
return a.text == b.text
case partToolRequest:
return reflect.DeepEqual(a.toolRequest, b.toolRequest)
case partToolResponse:
return reflect.DeepEqual(a.toolResponse, b.toolResponse)
switch a.Kind {
case PartText:
return a.Text == b.Text
case PartMedia:
return a.ContentType == b.ContentType && a.Text == b.Text
case PartData:
return a.Text == b.Text
case PartToolRequest:
return reflect.DeepEqual(a.ToolRequest, b.ToolRequest)
case PartToolResponse:
return reflect.DeepEqual(a.ToolResponse, b.ToolResponse)
default:
t.Fatalf("bad part kind %v", a.kind)
t.Fatalf("bad part kind %v", a.Kind)
return false
}
}
Expand Down
Loading