-
Notifications
You must be signed in to change notification settings - Fork 1
/
text_to_image.go
139 lines (116 loc) · 4.18 KB
/
text_to_image.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package stabilityai
import (
"fmt"
"github.com/instill-ai/component/pkg/base"
"google.golang.org/protobuf/types/known/structpb"
)
// Constants used by the text-to-image task.
const (
	// textToImagePathTemplate is the format of the text-to-image endpoint
	// path; its single %s verb is replaced by the engine.
	textToImagePathTemplate = "/v1/generation/%s/text-to-image"

	// successFinishReason is the finish reason an image must carry to be
	// included in the task output; images with any other value are skipped.
	successFinishReason = "SUCCESS"
)
// textToImagePath returns the text-to-image endpoint path for the given
// engine, obtained by substituting engine into textToImagePathTemplate.
// The engine is inserted as-is, without any escaping.
func textToImagePath(engine string) string {
	path := fmt.Sprintf(textToImagePathTemplate, engine)
	return path
}
// TextToImageInput is the input of the text-to-image task, as decoded from a
// structpb.Struct by parseTextToImageReq. Pointer fields are optional: a nil
// value means the parameter was not provided.
type TextToImageInput struct {
// Task is not read by the code visible here; presumably it names the task
// to run — confirm against the caller.
Task string `json:"task"`
// Prompts are the text prompts; at least one is required.
Prompts []string `json:"prompts"`
// Engine is substituted into the endpoint path and must not be empty.
Engine string `json:"engine"`
// Weights are matched to Prompts by index. Prompts without a matching
// weight get a weight of 1; weights beyond the last prompt are ignored.
Weights *[]float64 `json:"weights,omitempty"`
// The remaining fields are copied unchanged into the request.
Height *uint32 `json:"height,omitempty"`
Width *uint32 `json:"width,omitempty"`
CfgScale *float64 `json:"cfg_scale,omitempty"`
ClipGuidancePreset *string `json:"clip_guidance_preset,omitempty"`
Sampler *string `json:"sampler,omitempty"`
Samples *uint32 `json:"samples,omitempty"`
Seed *uint32 `json:"seed,omitempty"`
Steps *uint32 `json:"steps,omitempty"`
StylePreset *string `json:"style_preset,omitempty"`
}
// TextToImageOutput is the output of the text-to-image task, as assembled by
// textToImageOutput. Images and Seeds are parallel: the entries at a given
// index come from the same response image.
type TextToImageOutput struct {
// Images holds one "data:image/png;base64,..." string per included image.
Images []string `json:"images"`
// Seeds holds the seed reported for each included image.
Seeds []uint32 `json:"seeds"`
}
// TextToImageReq represents the request body for the text-to-image API.
// parseTextToImageReq builds it from a TextToImageInput: apart from
// TextPrompts, every exported field is an optional parameter copied unchanged
// from the input, and omitempty leaves nil ones out when the struct is
// marshaled with encoding/json.
// NOTE(review): the `om` tags are not consumed anywhere in the visible code —
// confirm which mapper reads them.
type TextToImageReq struct {
// TextPrompts holds one entry per input prompt, in input order.
TextPrompts []TextPrompt `json:"text_prompts" om:"texts[:]"`
CFGScale *float64 `json:"cfg_scale,omitempty" om:"metadata.cfg_scale"`
ClipGuidancePreset *string `json:"clip_guidance_preset,omitempty" om:"metadata.clip_guidance_preset"`
Sampler *string `json:"sampler,omitempty" om:"metadata.sampler"`
Samples *uint32 `json:"samples,omitempty" om:"metadata.samples"`
Seed *uint32 `json:"seed,omitempty" om:"metadata.seed"`
Steps *uint32 `json:"steps,omitempty" om:"metadata.steps"`
StylePreset *string `json:"style_preset,omitempty" om:"metadata.style_preset"`
Height *uint32 `json:"height,omitempty" om:"metadata.height"`
Width *uint32 `json:"width,omitempty" om:"metadata.width"`
// path is the endpoint path built from the selected engine. It is
// unexported, so encoding/json does not include it in the body.
path string
}
// TextPrompt holds a prompt's text and its weight.
type TextPrompt struct {
// Text is the prompt text.
Text string `json:"text" om:"."`
// Weight has no omitempty, so a nil pointer would be marshaled as null by
// encoding/json. parseTextToImageReq always sets it, using 1 when the
// input carries no weight for the prompt.
Weight *float64 `json:"weight"`
}
// Image represents a single image returned in a text-to-image response.
type Image struct {
// Base64 is the image payload; textToImageOutput embeds it verbatim in a
// "data:image/png;base64," string.
Base64 string `json:"base64"`
// Seed is the seed reported alongside the image.
Seed uint32 `json:"seed"`
// FinishReason is compared against successFinishReason to decide whether
// the image is included in the task output.
FinishReason string `json:"finishReason"`
}
// ImageTaskRes represents the response body for text-to-image API.
type ImageTaskRes struct {
// Images is read from the "artifacts" key of the response JSON.
Images []Image `json:"artifacts"`
}
// parseTextToImageReq converts the task input carried by from into the
// request sent to the text-to-image endpoint.
//
// It returns an error when from cannot be decoded into a TextToImageInput,
// when no prompt is given, or when the engine is empty. Each prompt is paired
// with the weight found at the same index; prompts without one get a weight
// of 1, and surplus weights are ignored.
func parseTextToImageReq(from *structpb.Struct) (TextToImageReq, error) {
	// Decode the protobuf struct into the typed input.
	var in TextToImageInput
	if err := base.ConvertFromStructpb(from, &in); err != nil {
		return TextToImageReq{}, err
	}

	// Validate: prompts are checked before the engine.
	switch {
	case len(in.Prompts) == 0:
		return TextToImageReq{}, fmt.Errorf("no text prompts given")
	case in.Engine == "":
		return TextToImageReq{}, fmt.Errorf("no engine selected")
	}

	// Weights are optional as a whole; a nil slice simply matches no prompt.
	var weights []float64
	if in.Weights != nil {
		weights = *in.Weights
	}

	prompts := make([]TextPrompt, len(in.Prompts))
	for i, text := range in.Prompts {
		// Declared per iteration so that every prompt points at its own value.
		weight := 1.0
		if i < len(weights) {
			weight = weights[i]
		}
		prompts[i] = TextPrompt{Text: text, Weight: &weight}
	}

	return TextToImageReq{
		TextPrompts:        prompts,
		CFGScale:           in.CfgScale,
		ClipGuidancePreset: in.ClipGuidancePreset,
		Sampler:            in.Sampler,
		Samples:            in.Samples,
		Seed:               in.Seed,
		Steps:              in.Steps,
		StylePreset:        in.StylePreset,
		Height:             in.Height,
		Width:              in.Width,
		path:               textToImagePath(in.Engine),
	}, nil
}
// textToImageOutput converts an API response into the task output struct.
//
// Only images whose finish reason equals successFinishReason are kept. Each
// kept image contributes a "data:image/png;base64," string and its seed, at
// the same index of the two output lists. The result of
// base.ConvertToStructpb is returned as-is.
func textToImageOutput(from ImageTaskRes) (*structpb.Struct, error) {
	// Both slices are non-nil even when no image is kept.
	images := make([]string, 0, len(from.Images))
	seeds := make([]uint32, 0, len(from.Images))

	for i := range from.Images {
		artifact := &from.Images[i]
		if artifact.FinishReason == successFinishReason {
			images = append(images, fmt.Sprintf("data:image/png;base64,%s", artifact.Base64))
			seeds = append(seeds, artifact.Seed)
		}
	}

	return base.ConvertToStructpb(TextToImageOutput{Images: images, Seeds: seeds})
}