/
EmbeddingsHelper.java
122 lines (96 loc) · 4.16 KB
/
EmbeddingsHelper.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package us.careydevelopment.ai.openai.support;
import com.theokanning.openai.embedding.Embedding;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.service.OpenAiService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Static helpers for requesting embeddings from the OpenAI service and for persisting them to,
 * and loading them back from, plain-text files.
 *
 * <p>On-disk format: one embedding per line, its components written with
 * {@link Double#toString(double)} and separated by commas.
 */
public class EmbeddingsHelper {

    private static final Logger LOG = LoggerFactory.getLogger(EmbeddingsHelper.class);

    /** Model name placed on every request built by {@link #getEmbeddingRequest(List)}. */
    private static final String DEFAULT_MODEL = "text-embedding-ada-002";

    /**
     * Builds an embedding request for the given texts using {@link #DEFAULT_MODEL}.
     *
     * @param inputText the texts to embed
     * @return the built request
     */
    private static EmbeddingRequest getEmbeddingRequest(final List<String> inputText) {
        return EmbeddingRequest
                .builder()
                .model(DEFAULT_MODEL)
                .input(inputText)
                .build();
    }

    /**
     * Requests embeddings for {@code inputText} from the OpenAI service.
     *
     * @param inputText the texts to embed; must not be {@code null}
     * @return the embeddings returned by the service ({@code getData()} of the response), or an
     *     empty list when {@code inputText} is empty, in which case no request is sent
     * @throws NullPointerException if {@code inputText} is {@code null}
     */
    public static List<Embedding> getEmbeddings(final List<String> inputText) {
        Objects.requireNonNull(inputText, "inputText");
        if (inputText.isEmpty()) {
            // Nothing to embed: avoid a pointless remote call (an empty input list is presumably
            // rejected by the API anyway — NOTE(review): confirm against the OpenAI API docs).
            return new ArrayList<>();
        }
        final OpenAiService service = OpenAiServiceHelper.getOpenAiService();
        return service
                .createEmbeddings(getEmbeddingRequest(inputText))
                .getData();
    }

    /**
     * Reads every line of a text file and requests one embedding input per line.
     *
     * @param filePath path of the text file to read (decoded as UTF-8 by {@link Files#lines(Path)})
     * @return the embeddings returned by {@link #getEmbeddings(List)}
     * @throws IOException if the file cannot be read
     */
    public static List<Embedding> getEmbeddingsFromFile(final String filePath) throws IOException {
        try (final Stream<String> stream = Files.lines(Paths.get(filePath))) {
            final List<String> lines = stream.collect(Collectors.toList());
            return getEmbeddings(lines);
        }
    }

    /**
     * Embeds each line of {@code inputPath} and writes the vectors to {@code outputPath}.
     *
     * @param inputPath text file whose lines are embedded
     * @param outputPath file that receives one comma-separated vector per line
     * @return {@code true} on success; {@code false} if any exception occurred (it is logged)
     */
    public static boolean persistEmbeddingsFromFile(final String inputPath, final String outputPath) {
        try {
            final List<Embedding> embeddings = getEmbeddingsFromFile(inputPath);
            saveFile(embeddings, outputPath);
        } catch (Exception e) {
            // Boundary method: callers get a boolean instead of an exception.
            LOG.error("Problem saving embeddings!", e);
            return false;
        }
        return true;
    }

    /**
     * Embeds {@code inputText} and writes the vectors to {@code pathStr}.
     *
     * @param inputText the texts to embed
     * @param pathStr file that receives one comma-separated vector per line
     * @return {@code true} on success; {@code false} if any exception occurred (it is logged)
     */
    public static boolean persistEmbeddings(final List<String> inputText, final String pathStr) {
        try {
            final List<Embedding> embeddings = getEmbeddings(inputText);
            saveFile(embeddings, pathStr);
        } catch (Exception e) {
            // Boundary method: callers get a boolean instead of an exception.
            LOG.error("Problem saving embeddings!", e);
            return false;
        }
        return true;
    }

    /**
     * Writes the embeddings to {@code pathStr}, one comma-separated vector per line. Uses the
     * default options of {@link Files#write(Path, Iterable, java.nio.file.OpenOption...)}, so an
     * existing file is truncated.
     */
    private static void saveFile(final List<Embedding> embeddings, final String pathStr) throws IOException {
        final List<String> output = getEmbeddingsAsStringList(embeddings);
        final Path path = Paths.get(pathStr);
        Files.write(path, output);
    }

    /**
     * Renders each embedding vector as a single comma-separated line of
     * {@link Double#toString(double)} values.
     *
     * @param embeddings the embeddings to render
     * @return one string per embedding, in the same order
     */
    static List<String> getEmbeddingsAsStringList(final List<Embedding> embeddings) {
        return embeddings
                .stream()
                .map(Embedding::getEmbedding)
                .map(vector -> vector.stream()
                        .map(component -> Double.toString(component))
                        .collect(Collectors.joining(",")))
                .collect(Collectors.toList());
    }

    /**
     * Loads a file written by {@link #persistEmbeddings} / {@link #persistEmbeddingsFromFile} as a
     * matrix of doubles (one row per non-blank line).
     *
     * @param pathStr path of the embeddings file
     * @return the rows, in file order
     * @throws IOException if the file cannot be read
     * @throws NumberFormatException if a component is not a valid double
     */
    public static List<List<Double>> loadFromFile(final String pathStr) throws IOException {
        return loadMatrix(pathStr, Double::parseDouble);
    }

    /**
     * Loads a file written by {@link #persistEmbeddings} / {@link #persistEmbeddingsFromFile} as a
     * matrix of floats (one row per non-blank line). Values are narrowed to float precision.
     *
     * @param pathStr path of the embeddings file
     * @return the rows, in file order
     * @throws IOException if the file cannot be read
     * @throws NumberFormatException if a component is not a valid float
     */
    public static List<List<Float>> loadFromFileAsFloats(final String pathStr) throws IOException {
        return loadMatrix(pathStr, Float::parseFloat);
    }

    /**
     * Shared reader for the comma-separated matrix format; {@code parser} converts one component.
     * Blank lines (e.g. a trailing empty line added by an editor) are skipped rather than failing
     * with {@code NumberFormatException} on the empty string.
     */
    private static <T> List<List<T>> loadMatrix(final String pathStr, final Function<String, T> parser)
            throws IOException {
        try (Stream<String> lines = Files.lines(Paths.get(pathStr))) {
            return lines
                    .filter(line -> !line.trim().isEmpty())
                    .map(line -> Arrays.stream(line.split(","))
                            .map(parser)
                            .collect(Collectors.toList()))
                    .collect(Collectors.toCollection(ArrayList::new));
        }
    }
}