forked from tmc/langchaingo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
templates.go
83 lines (70 loc) · 2.44 KB
/
templates.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
package prompts
import (
"errors"
"fmt"
"strings"
"text/template"
"github.com/Masterminds/sprig/v3"
"golang.org/x/exp/maps"
)
// ErrInvalidTemplateFormat is the sentinel error reported when a requested
// template format is invalid or not supported. Errors built from it in this
// file wrap it with %w, so callers can match it with errors.Is.
var ErrInvalidTemplateFormat = errors.New("invalid template format")
// TemplateFormat is the format of the template. Its value is the key used to
// look up the matching interpolator in defaultformatterMapping.
type TemplateFormat string
const (
// TemplateFormatGoTemplate is the format for go-template, i.e. templates
// rendered through text/template by interpolateGoTemplate.
TemplateFormatGoTemplate TemplateFormat = "go-template"
)
// interpolator is the signature of a function that interpolates the given
// template with the given values, returning the rendered text or an error.
type interpolator func(template string, values map[string]any) (string, error)
// defaultformatterMapping is the default mapping of TemplateFormat to interpolator.
// A TemplateFormat without an entry here is treated as unsupported by
// CheckValidTemplate and RenderTemplate.
var defaultformatterMapping = map[TemplateFormat]interpolator{ //nolint:gochecknoglobals
TemplateFormatGoTemplate: interpolateGoTemplate,
}
// interpolateGoTemplate renders tmpl with text/template, using values as the
// template data and exposing the sprig function map to the template.
//
// The template is configured with "missingkey=error", so referencing a map
// key that is absent from values makes execution fail instead of rendering a
// placeholder. On any parse or execution error the returned string is empty
// and the underlying error is returned unchanged.
func interpolateGoTemplate(tmpl string, values map[string]any) (string, error) {
	engine := template.New("template").Option("missingkey=error").Funcs(sprig.FuncMap())

	parsed, err := engine.Parse(tmpl)
	if err != nil {
		return "", err
	}

	var out strings.Builder
	if err := parsed.Execute(&out, values); err != nil {
		return "", err
	}
	return out.String(), nil
}
// newInvalidTemplateError builds the error returned for an unsupported
// template format. The result wraps ErrInvalidTemplateFormat and names both
// the rejected format and the formats registered in defaultformatterMapping.
// The registered formats are listed as returned by maps.Keys; they are not
// sorted here.
func newInvalidTemplateError(gotTemplateFormat TemplateFormat) error {
	supported := maps.Keys(defaultformatterMapping)
	return fmt.Errorf(
		"%w, got: %s, should be one of %s",
		ErrInvalidTemplateFormat, gotTemplateFormat, supported,
	)
}
// CheckValidTemplate reports whether template can be used with templateFormat.
//
// It first verifies that templateFormat has a registered interpolator, then
// performs a trial render in which every name in inputVariables is bound to
// the placeholder string "foo". It returns nil when both steps succeed, an
// error wrapping ErrInvalidTemplateFormat when the format is unknown, or the
// rendering error otherwise.
func CheckValidTemplate(template string, templateFormat TemplateFormat, inputVariables []string) error {
	if _, known := defaultformatterMapping[templateFormat]; !known {
		return newInvalidTemplateError(templateFormat)
	}

	// Bind each declared variable to a throwaway value so the trial render
	// only fails on problems with the template itself.
	placeholders := make(map[string]any, len(inputVariables))
	for i := range inputVariables {
		placeholders[inputVariables[i]] = "foo"
	}

	if _, err := RenderTemplate(template, templateFormat, placeholders); err != nil {
		return err
	}
	return nil
}
// RenderTemplate renders tmpl with values using the interpolator registered
// for tmplFormat. When no interpolator is registered for tmplFormat it
// returns an empty string and an error wrapping ErrInvalidTemplateFormat.
func RenderTemplate(tmpl string, tmplFormat TemplateFormat, values map[string]any) (string, error) {
	if render, found := defaultformatterMapping[tmplFormat]; found {
		return render(tmpl, values)
	}
	return "", newInvalidTemplateError(tmplFormat)
}