Skip to content

Commit 632b7cb

Browse files
Contain task type compatibility in TransportInferenceUsageAction
1 parent 809631b commit 632b7cb

File tree

5 files changed

+35
-46
lines changed

5 files changed

+35
-46
lines changed

server/src/main/java/org/elasticsearch/inference/TaskType.java

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import java.util.Objects;
2222

2323
public enum TaskType implements Writeable {
24-
TEXT_EMBEDDING(true),
25-
SPARSE_EMBEDDING(true),
24+
TEXT_EMBEDDING,
25+
SPARSE_EMBEDDING,
2626
RERANK,
2727
COMPLETION,
2828
ANY {
@@ -52,16 +52,6 @@ public static TaskType fromStringOrStatusException(String name) {
5252
}
5353
}
5454

55-
private final boolean isCompatibleWithSemanticText;
56-
57-
TaskType(boolean isCompatibleWithSemanticText) {
58-
this.isCompatibleWithSemanticText = isCompatibleWithSemanticText;
59-
}
60-
61-
TaskType() {
62-
this(false);
63-
}
64-
6555
/**
6656
* Return true if the {@code other} is the {@link #ANY} type
6757
* or the same as this.
@@ -72,14 +62,6 @@ public boolean isAnyOrSame(TaskType other) {
7262
return other == TaskType.ANY || other == this;
7363
}
7464

75-
/**
76-
* Returns true if this task type is compatible with semantic text.
77-
* @return True if this task type is compatible with semantic text.
78-
*/
79-
public boolean isCompatibleWithSemanticText() {
80-
return isCompatibleWithSemanticText;
81-
}
82-
8365
@Override
8466
public String toString() {
8567
return name().toLowerCase(Locale.ROOT);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/usage/ModelStats.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,6 @@ public class ModelStats implements ToXContentObject, Writeable {
3434
@Nullable
3535
private final SemanticTextStats semanticTextStats;
3636

37-
public ModelStats(String service, TaskType taskType) {
38-
this(service, taskType, 0L);
39-
}
40-
41-
public ModelStats(String service, TaskType taskType, long count) {
42-
this(service, taskType, count, taskType.isCompatibleWithSemanticText() ? new SemanticTextStats() : null);
43-
}
44-
4537
public ModelStats(String service, TaskType taskType, long count, @Nullable SemanticTextStats semanticTextStats) {
4638
this.service = service;
4739
this.taskType = taskType;

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ protected ModelStats mutateInstance(ModelStats modelStats) throws IOException {
5555
}
5656

5757
public void testAdd() {
58-
ModelStats stats = new ModelStats("test_service", randomFrom(TaskType.values()));
58+
ModelStats stats = new ModelStats("test_service", randomFrom(TaskType.values()), 0, null);
5959
assertThat(stats.count(), equalTo(0L));
6060

6161
stats.add();
@@ -74,7 +74,7 @@ public static ModelStats createRandomInstance() {
7474
randomIdentifier(),
7575
taskType,
7676
randomLong(),
77-
taskType.isCompatibleWithSemanticText() ? SemanticTextStatsTests.createRandomInstance() : null
77+
randomBoolean() ? SemanticTextStatsTests.createRandomInstance() : null
7878
);
7979
}
8080

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@
3232
import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage;
3333
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
3434
import org.elasticsearch.xpack.core.inference.usage.ModelStats;
35+
import org.elasticsearch.xpack.core.inference.usage.SemanticTextStats;
3536
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3637

3738
import java.util.ArrayList;
39+
import java.util.EnumSet;
3840
import java.util.HashMap;
3941
import java.util.HashSet;
4042
import java.util.List;
@@ -55,6 +57,11 @@ public class TransportInferenceUsageAction extends XPackUsageFeatureTransportAct
5557
// Some of the default models have optimized variants for linux that will have the following suffix.
5658
private static final String MODEL_ID_LINUX_SUFFIX = "_linux-x86_64";
5759

60+
private static final EnumSet<TaskType> TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT = EnumSet.of(
61+
TaskType.TEXT_EMBEDDING,
62+
TaskType.SPARSE_EMBEDDING
63+
);
64+
5865
private final ModelRegistry modelRegistry;
5966
private final Client client;
6067

@@ -144,12 +151,12 @@ private static void addStatsByServiceAndTask(
144151
for (ModelConfigurations model : endpoints) {
145152
endpointStats.computeIfAbsent(
146153
new ServiceAndTaskType(model.getService(), model.getTaskType()).toString(),
147-
key -> new ModelStats(model.getService(), model.getTaskType())
154+
key -> createEmptyStats(model)
148155
).add();
149156

150157
endpointStats.computeIfAbsent(
151158
new ServiceAndTaskType(Metadata.ALL, model.getTaskType()).toString(),
152-
key -> new ModelStats(Metadata.ALL, model.getTaskType())
159+
key -> createEmptyStats(Metadata.ALL, model.getTaskType())
153160
).add();
154161
}
155162

@@ -162,6 +169,19 @@ private static void addStatsByServiceAndTask(
162169
addTopLevelStatsByTask(inferenceFieldsByIndexServiceAndTask, endpointStats);
163170
}
164171

172+
private static ModelStats createEmptyStats(ModelConfigurations model) {
173+
return createEmptyStats(model.getService(), model.getTaskType());
174+
}
175+
176+
private static ModelStats createEmptyStats(String service, TaskType taskType) {
177+
return new ModelStats(
178+
service,
179+
taskType,
180+
0,
181+
TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(taskType) ? new SemanticTextStats() : null
182+
);
183+
}
184+
165185
private static void addTopLevelStatsByTask(
166186
Map<ServiceAndTaskType, Map<String, List<InferenceFieldMetadata>>> inferenceFieldsByIndexServiceAndTask,
167187
Map<String, ModelStats> endpointStats
@@ -172,9 +192,9 @@ private static void addTopLevelStatsByTask(
172192
}
173193
ModelStats allStatsForTaskType = endpointStats.computeIfAbsent(
174194
new ServiceAndTaskType(Metadata.ALL, taskType).toString(),
175-
key -> new ModelStats(Metadata.ALL, taskType)
195+
key -> createEmptyStats(Metadata.ALL, taskType)
176196
);
177-
if (taskType.isCompatibleWithSemanticText()) {
197+
if (TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(taskType)) {
178198
Map<String, List<InferenceFieldMetadata>> inferenceFieldsByIndex = inferenceFieldsByIndexServiceAndTask.entrySet()
179199
.stream()
180200
.filter(e -> e.getKey().taskType == taskType)
@@ -226,7 +246,12 @@ private void addStatsForDefaultModelsCompatibleWithSemanticText(
226246
// Now that we have all inference fields for this service and task type, we want to keep only the ones that
227247
// reference the current default model.
228248
fieldsByIndex = filterFields(fieldsByIndex, f -> statKey.modelId.equals(endpointIdToModelId.get(f.getInferenceId())));
229-
ModelStats stats = new ModelStats(statKey.toString(), statKey.taskType, defaultModelStatsKeyToEndpointCount.getValue());
249+
ModelStats stats = new ModelStats(
250+
statKey.toString(),
251+
statKey.taskType,
252+
defaultModelStatsKeyToEndpointCount.getValue(),
253+
new SemanticTextStats()
254+
);
230255
addSemanticTextStats(fieldsByIndex, stats);
231256
endpointStats.put(statKey.toString(), stats);
232257
}
@@ -239,7 +264,7 @@ private Map<DefaultModelStatsKey, Long> createStatsKeysWithEndpointCountsForDefa
239264
// Note that endpoints could have a null model id, in which case we don't consider them default as this
240265
// may only happen for external services.
241266
Set<String> modelIds = endpoints.stream()
242-
.filter(endpoint -> endpoint.getTaskType().isCompatibleWithSemanticText())
267+
.filter(endpoint -> TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(endpoint.getTaskType()))
243268
.filter(endpoint -> modelRegistry.containsDefaultConfigId(endpoint.getInferenceEntityId()))
244269
.filter(endpoint -> endpoint.getServiceSettings().modelId() != null)
245270
.map(endpoint -> stripLinuxSuffix(endpoint.getServiceSettings().modelId()))

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
import org.elasticsearch.test.ESTestCase;
1313
import org.hamcrest.Matchers;
1414

15-
import static org.hamcrest.core.Is.is;
16-
1715
public class TaskTypeTests extends ESTestCase {
1816

1917
public void testFromStringOrStatusException() {
@@ -26,12 +24,4 @@ public void testFromStringOrStatusException() {
2624
assertThat(TaskType.fromStringOrStatusException("any"), Matchers.is(TaskType.ANY));
2725
}
2826

29-
public void testIsCompatibleWithSemanticText() {
30-
assertThat(TaskType.ANY.isCompatibleWithSemanticText(), is(false));
31-
assertThat(TaskType.CHAT_COMPLETION.isCompatibleWithSemanticText(), is(false));
32-
assertThat(TaskType.COMPLETION.isCompatibleWithSemanticText(), is(false));
33-
assertThat(TaskType.RERANK.isCompatibleWithSemanticText(), is(false));
34-
assertThat(TaskType.TEXT_EMBEDDING.isCompatibleWithSemanticText(), is(true));
35-
assertThat(TaskType.SPARSE_EMBEDDING.isCompatibleWithSemanticText(), is(true));
36-
}
3727
}

0 commit comments

Comments
 (0)