-
Notifications
You must be signed in to change notification settings - Fork 1
/
context.go
128 lines (106 loc) · 3.64 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package whisper
// #cgo linux LDFLAGS: -l:libwhisper.a -lm -lstdc++
// #cgo darwin LDFLAGS: -lwhisper -lstdc++ -framework Accelerate
// #include <whisper.h>
// #include <stdlib.h>
import "C"
import (
"fmt"
"log/slog"
"os"
"runtime"
"unsafe"
"github.com/mattermost/calls-transcriber/cmd/transcriber/transcribe"
)
// Config holds the settings used to create a transcription Context.
type Config struct {
	// The path to the GGML model file to use.
	ModelFile string
	// The number of system threads to use to perform the transcription.
	// Must be in the range [1, runtime.NumCPU()].
	NumThreads int
	// Whether or not to skip using past transcription as prompt.
	// NOTE(review): the value is forwarded as-is to whisper's no_context
	// parameter, so true presumably disables the prompt — confirm against whisper.h.
	NoContext bool
	// 512 = a bit more than 10s. Use multiples of 64. Results in a speedup of 3x at 512, b/c whisper was tuned for 30s chunks. See: https://github.com/ggerganov/whisper.cpp/pull/141
	// TODO: tests, validation
	AudioContext int
	// Whether or not to print progress to stdout (default false).
	PrintProgress bool
	// Language to use (defaults to autodetection).
	Language string
	// Whether or not to generate a single segment (default false).
	SingleSegment bool
}

// IsValid reports whether the config can be used to create a Context.
//
// It returns a non-nil error when the config is the zero value, when
// ModelFile is empty or cannot be stat'ed, or when NumThreads falls outside
// the range [1, runtime.NumCPU()]. It returns nil otherwise.
func (c Config) IsValid() error {
	if c == (Config{}) {
		return fmt.Errorf("invalid empty config")
	}

	if c.ModelFile == "" {
		return fmt.Errorf("invalid ModelFile: should not be empty")
	}

	if _, err := os.Stat(c.ModelFile); err != nil {
		return fmt.Errorf("invalid ModelFile: failed to stat model file: %w", err)
	}

	// Compare with "< 1" rather than "== 0" so that negative thread counts are
	// rejected as well, matching the range advertised in the error message.
	if numCPU := runtime.NumCPU(); c.NumThreads < 1 || c.NumThreads > numCPU {
		return fmt.Errorf("invalid NumThreads: should be in the range [1, %d]", numCPU)
	}

	return nil
}
// Context bundles a loaded whisper model with the parameters that are passed
// to whisper_full on every Transcribe call.
type Context struct {
	// cfg is the configuration the context was created from (with Language
	// defaulted to "auto" when it was left empty).
	cfg Config
	// ctx is the native handle returned by whisper_init_from_file. It is set
	// to nil once Destroy has released it.
	ctx *C.struct_whisper_context
	// params holds the whisper_full parameters derived from cfg.
	// params.language points to a C-allocated string that Destroy frees.
	params C.struct_whisper_full_params
}
// NewContext validates cfg, loads the model file it points to and returns a
// Context ready to be used for transcription.
//
// An error is returned if cfg does not pass IsValid or if the model cannot be
// loaded. The returned Context owns native memory and must be released with
// Destroy.
func NewContext(cfg Config) (*Context, error) {
	if err := cfg.IsValid(); err != nil {
		return nil, fmt.Errorf("failed to validate config: %w", err)
	}

	slog.Debug("creating transcription context", slog.Any("cfg", cfg))

	// TODO: verify whether there's any potential optimizations
	// that could be made by using lower level initialization methods
	// such as whisper_init or whisper_init_from_buffer.
	modelPath := C.CString(cfg.ModelFile)
	defer C.free(unsafe.Pointer(modelPath))

	wctx := C.whisper_init_from_file(modelPath)
	if wctx == nil {
		return nil, fmt.Errorf("failed to load model file")
	}

	// An empty language falls back to "auto"; cfg is a local copy, so the
	// caller's value is left untouched.
	if cfg.Language == "" {
		cfg.Language = "auto"
	}

	// Start from the library defaults for greedy sampling and override only
	// the fields exposed through Config.
	params := C.whisper_full_default_params(C.WHISPER_SAMPLING_GREEDY)
	params.no_context = C.bool(cfg.NoContext)
	params.audio_ctx = C.int(cfg.AudioContext)
	params.n_threads = C.int(cfg.NumThreads)
	// The C string below is owned by the Context and released in Destroy.
	params.language = C.CString(cfg.Language)
	params.single_segment = C.bool(cfg.SingleSegment)
	params.print_progress = C.bool(cfg.PrintProgress)

	return &Context{
		cfg:    cfg,
		ctx:    wctx,
		params: params,
	}, nil
}
// Destroy releases the native resources held by the context: the whisper
// handle and the C-allocated language string stored in the parameters.
//
// It returns an error if the context holds no whisper handle, which is the
// case after a previous successful call to Destroy.
func (c *Context) Destroy() error {
	if c.ctx == nil {
		return fmt.Errorf("context is not initialized")
	}

	C.whisper_free(c.ctx)
	c.ctx = nil

	// Reset the pointer after freeing it so that the parameters never carry a
	// dangling reference to released memory.
	C.free(unsafe.Pointer(c.params.language))
	c.params.language = nil

	return nil
}
// Transcribe runs whisper_full over the given audio samples and returns the
// resulting segments along with the language string reported by whisper for
// the run.
//
// An error is returned if samples is empty, if the context holds no whisper
// handle (e.g. Destroy was already called), or if whisper_full returns a
// non-zero code.
//
// NOTE(review): segment timestamps are multiplied by 10, presumably to convert
// whisper's 10ms units into milliseconds — confirm against whisper.h.
func (c *Context) Transcribe(samples []float32) ([]transcribe.Segment, string, error) {
	if len(samples) == 0 {
		return nil, "", fmt.Errorf("samples should not be empty")
	}

	// Guard against use after Destroy (or of a zero-value Context): passing a
	// nil handle down to the C library would crash the process.
	if c.ctx == nil {
		return nil, "", fmt.Errorf("context is not initialized")
	}

	ret := C.whisper_full(c.ctx, c.params, (*C.float)(&samples[0]), C.int(len(samples)))
	if ret != 0 {
		return nil, "", fmt.Errorf("whisper_full failed with code %d", ret)
	}

	lang := C.GoString(C.whisper_lang_str(C.whisper_full_lang_id(c.ctx)))

	n := int(C.whisper_full_n_segments(c.ctx))
	segments := make([]transcribe.Segment, n)
	for i := range segments {
		idx := C.int(i)
		segments[i].Text = C.GoString(C.whisper_full_get_segment_text(c.ctx, idx))
		segments[i].StartTS = int64(C.whisper_full_get_segment_t0(c.ctx, idx)) * 10
		segments[i].EndTS = int64(C.whisper_full_get_segment_t1(c.ctx, idx)) * 10
	}

	return segments, lang, nil
}