-
-
Notifications
You must be signed in to change notification settings - Fork 98
/
common.ts
100 lines (78 loc) · 2.3 KB
/
common.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import { AppSchema } from '../../db/schema'
/**
 * Cut a generated reply at the earliest end token.
 *
 * A single response can contain several end tokens; only the first occurrence in
 * the text matters. Candidate tokens are `getEndTokens(char, members)` plus the
 * caller-supplied `endTokens`. The text before the earliest match — or the whole
 * text when nothing matches — is returned after passing through `sanitise`.
 *
 * @param generated text to trim
 * @param char character passed to `getEndTokens` (its speaker label ends the reply)
 * @param members profiles passed to `getEndTokens` (their speaker labels end the reply)
 * @param endTokens extra stop strings; empty strings are ignored
 * @returns the trimmed, whitespace-normalised reply
 */
export function trimResponse(
  generated: string,
  char: AppSchema.Character,
  members: AppSchema.Profile[],
  endTokens: string[] = []
) {
  const candidates = getEndTokens(char, members).concat(endTokens)

  let cutAt = -1
  for (const token of candidates) {
    // An empty token matches at index 0 of any string and would erase the entire reply.
    if (!token) continue
    const idx = generated.indexOf(token)
    if (idx !== -1 && (cutAt === -1 || idx < cutAt)) {
      cutAt = idx
    }
  }

  return sanitise(cutAt === -1 ? generated : generated.slice(0, cutAt))
}
/**
 * Cut a generated reply at the earliest end token, then strip the character's own
 * speaker labels from what remains.
 *
 * Unlike `trimResponse`, the character's `Name:` label is not added as a stop token
 * (`getEndTokens` is called with `null` for the character); every `Name:` occurrence
 * is removed from the kept text instead. Before matching, `Name :` / `handle :`
 * (space before the colon) is normalised to `Name:` / `handle:` for the character
 * and for every member that has a handle.
 *
 * @param generated text to trim
 * @param char character whose `Name:` labels are stripped from the result
 * @param members profiles passed to `getEndTokens` (their speaker labels end the reply)
 * @param endTokens extra stop strings; empty strings are ignored
 * @returns the trimmed, label-stripped, whitespace-normalised reply
 */
export function trimResponseV2(
  generated: string,
  char: AppSchema.Character,
  members: AppSchema.Profile[],
  endTokens: string[] = []
) {
  // Collapse "Name :" into "Name:" so matching and stripping only deal with one form.
  let text = generated.split(`${char.name} :`).join(`${char.name}:`)
  for (const member of members) {
    if (!member.handle) continue
    text = text.split(`${member.handle} :`).join(`${member.handle}:`)
  }

  const candidates = getEndTokens(null, members).concat(endTokens)

  let cutAt = -1
  for (const token of candidates) {
    // An empty token matches at index 0 of any string and would erase the entire reply.
    if (!token) continue
    const idx = text.indexOf(token)
    if (idx !== -1 && (cutAt === -1 || idx < cutAt)) {
      cutAt = idx
    }
  }

  const kept = cutAt === -1 ? text : text.slice(0, cutAt)
  return sanitise(kept.split(`${char.name}:`).join(''))
}
/**
 * Build the de-duplicated list of strings that mark the end of a generated reply.
 *
 * The list always starts with `END_OF_DIALOG`, `<END>` and a blank line (`\n\n`),
 * followed by any caller-supplied `endTokens`. After those come the `Name:` /
 * `Name :` speaker labels for `char` (when given and named) and for each member
 * with a non-empty handle.
 *
 * Empty tokens are never emitted. An empty string matches at index 0 of any text.
 * A blank name or handle would produce a bare `:` token that cuts a reply at its
 * first colon. A missing handle would produce a literal `undefined:` token.
 *
 * @param char character whose speaker label should end the reply, or `null` to omit it
 * @param members profiles whose speaker labels should end the reply
 * @param endTokens additional stop strings, inserted after the built-in markers
 * @returns unique tokens in first-seen order
 */
export function getEndTokens(
  char: AppSchema.Character | null,
  members: AppSchema.Profile[],
  endTokens: string[] = []
) {
  const tokens = ['END_OF_DIALOG', '<END>', '\n\n', ...endTokens]

  if (char?.name) {
    tokens.push(`${char.name}:`, `${char.name} :`)
  }

  for (const member of members) {
    if (!member.handle) continue
    tokens.push(`${member.handle}:`, `${member.handle} :`)
  }

  // Set preserves insertion order, so the first occurrence of each token wins.
  return Array.from(new Set(tokens.filter((token) => token.length > 0)))
}
/**
 * Join text fragments with single spaces and normalise whitespace.
 *
 * Every run of whitespace collapses to one space, whether it sits inside a part or
 * at a join point, and the result is trimmed. Empty or whitespace-only parts
 * therefore do not leave a double space behind: `['a', '', 'b']` becomes `'a b'`.
 *
 * @param parts fragments to join, in order
 * @returns the joined, whitespace-normalised text
 */
export function joinParts(parts: string[]) {
  return sanitise(parts.join(' '))
}
/**
 * Normalise whitespace in a piece of text.
 *
 * Every run of whitespace characters (spaces, tabs, newlines, etc.) becomes a single
 * space, and whitespace at either end is removed.
 *
 * @param generated text to normalise
 * @returns the normalised text ('' when the input is empty or all whitespace)
 */
export function sanitise(generated: string) {
  const words = generated.split(/\s+/).filter((word) => word !== '')
  return words.join(' ')
}