-
Notifications
You must be signed in to change notification settings - Fork 0
/
splitter.go
115 lines (96 loc) · 2.42 KB
/
splitter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package llmflow
import (
"context"
"fmt"
"strings"
"github.com/RussellLuo/orchestrator"
)
// TypeSplitter is the task type name used for Splitter, both in its
// TaskHeader and in the task factory that creates it.
const TypeSplitter = "splitter"
// init registers the splitter task factory with orchestrator.GlobalRegistry
// when the package is initialized.
func init() {
MustRegisterSplitter(orchestrator.GlobalRegistry)
}
// MustRegisterSplitter adds the factory for Splitter tasks to r under the
// TypeSplitter type name. The registration goes through r.MustRegister.
func MustRegisterSplitter(r *orchestrator.Registry) {
	factory := &orchestrator.TaskFactory{
		Type: TypeSplitter,
		// Every call hands out a fresh, zero-valued Splitter.
		New: func() orchestrator.Task {
			return &Splitter{}
		},
	}
	r.MustRegister(factory)
}
// Splitter is a task that splits the text of documents into smaller chunks
// and emits each chunk as a new document.
type Splitter struct {
orchestrator.TaskHeader
// Input holds the task's configuration.
Input struct {
// Documents is an expression yielding the documents to split; Execute
// evaluates it against the task input.
Documents orchestrator.Expr[[]*Document] `json:"documents"`
// SplitChars lists the separators. Every rune of every entry is treated as
// a split point. Init installs a default set when this is empty.
SplitChars []string `json:"split_chars"`
// ChunkSize is the limit used when packing pieces into chunks; it is
// compared against byte lengths.
ChunkSize int `json:"chunk_size"`
} `json:"input"`
}
// Init installs the default separators when none were configured.
// The registry argument is not used and the returned error is always nil.
func (s *Splitter) Init(r *orchestrator.Registry) error {
	if len(s.Input.SplitChars) != 0 {
		// Explicitly configured separators are kept untouched.
		return nil
	}

	defaults := []string{"\n", "。", "!", "?"}
	s.Input.SplitChars = defaults
	return nil
}
// String renders the task as "<type>(name:<name>)".
func (s *Splitter) String() string {
	const layout = "%s(name:%s)"
	return fmt.Sprintf(layout, s.Type, s.Name)
}
// Execute evaluates the Documents expression against input, splits the text
// of every resulting document into chunks, and returns the chunks as new
// documents.
//
// Each chunk document gets the ID "<source document ID>-<chunk index>" and
// copies only SourceID from the source document's metadata. The result is
// produced by encoding a struct whose single field is tagged "documents"
// through orchestrator.DefaultCodec. ctx is not used.
//
// An error is returned when the expression cannot be evaluated, when encoding
// fails, or when the codec yields something other than a map[string]any.
func (s *Splitter) Execute(ctx context.Context, input orchestrator.Input) (orchestrator.Output, error) {
	if err := s.Input.Documents.Evaluate(input); err != nil {
		return nil, err
	}

	var docs []*Document
	for _, doc := range s.Input.Documents.Value {
		for i, chunk := range s.split(doc.Text) {
			docs = append(docs, &Document{
				ID:   fmt.Sprintf("%s-%d", doc.ID, i),
				Text: chunk,
				Metadata: Metadata{
					SourceID: doc.Metadata.SourceID,
				},
			})
		}
	}

	encoded, err := orchestrator.DefaultCodec.Encode(struct {
		Documents []*Document `json:"documents"`
	}{
		Documents: docs,
	})
	if err != nil {
		return nil, err
	}

	// Use the two-value assertion so an unexpected codec result surfaces as an
	// error to the caller instead of panicking.
	output, ok := encoded.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("%s: unexpected encoded output type %T", s, encoded)
	}
	return output, nil
}
// split breaks text at every rune that occurs in any of the configured
// SplitChars and greedily packs the resulting pieces, in order, into chunks.
//
// A chunk is closed as soon as appending the next piece would make it longer
// than ChunkSize, measured in bytes rather than runes. A single piece that is
// itself longer than ChunkSize becomes a chunk of its own. Separator runes are
// dropped and pieces are concatenated without any separator. The result is
// nil when text contains nothing but separators.
func (s *Splitter) split(text string) []string {
	isSeparator := func(r rune) bool {
		for _, chars := range s.Input.SplitChars {
			if strings.ContainsRune(chars, r) {
				return true
			}
		}
		return false
	}

	var (
		chunks []string
		cur    strings.Builder
	)
	// Range over every piece so the last one is packed as well.
	for _, part := range strings.FieldsFunc(text, isSeparator) {
		if cur.Len() > 0 && cur.Len()+len(part) > s.Input.ChunkSize {
			chunks = append(chunks, cur.String())
			cur.Reset()
		}
		cur.WriteString(part)
	}
	// Flush the chunk still being built; otherwise the tail of the text
	// would never make it into the result.
	if cur.Len() > 0 {
		chunks = append(chunks, cur.String())
	}
	return chunks
}
// SplitterBuilder holds a Splitter task under construction. NewSplitter
// creates one and Build returns the task it holds.
type SplitterBuilder struct {
// task is the Splitter returned by Build.
task *Splitter
}
// NewSplitter starts building a Splitter task with the given name and the
// TypeSplitter task type. All other fields are left at their zero values.
func NewSplitter(name string) *SplitterBuilder {
	header := orchestrator.TaskHeader{
		Name: name,
		Type: TypeSplitter,
	}
	return &SplitterBuilder{
		task: &Splitter{TaskHeader: header},
	}
}
// Build returns the Splitter held by the builder as an orchestrator.Task.
func (b *SplitterBuilder) Build() orchestrator.Task { return b.task }