3232import org .elasticsearch .xpack .core .inference .InferenceFeatureSetUsage ;
3333import org .elasticsearch .xpack .core .inference .action .GetInferenceModelAction ;
3434import org .elasticsearch .xpack .core .inference .usage .ModelStats ;
35+ import org .elasticsearch .xpack .core .inference .usage .SemanticTextStats ;
3536import org .elasticsearch .xpack .inference .registry .ModelRegistry ;
3637
3738import java .util .ArrayList ;
39+ import java .util .EnumSet ;
3840import java .util .HashMap ;
3941import java .util .HashSet ;
4042import java .util .List ;
@@ -55,6 +57,11 @@ public class TransportInferenceUsageAction extends XPackUsageFeatureTransportAct
5557 // Some of the default models have optimized variants for linux that will have the following suffix.
5658 private static final String MODEL_ID_LINUX_SUFFIX = "_linux-x86_64" ;
5759
60+ private static final EnumSet <TaskType > TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT = EnumSet .of (
61+ TaskType .TEXT_EMBEDDING ,
62+ TaskType .SPARSE_EMBEDDING
63+ );
64+
5865 private final ModelRegistry modelRegistry ;
5966 private final Client client ;
6067
@@ -144,12 +151,12 @@ private static void addStatsByServiceAndTask(
144151 for (ModelConfigurations model : endpoints ) {
145152 endpointStats .computeIfAbsent (
146153 new ServiceAndTaskType (model .getService (), model .getTaskType ()).toString (),
147- key -> new ModelStats (model . getService (), model . getTaskType () )
154+ key -> createEmptyStats (model )
148155 ).add ();
149156
150157 endpointStats .computeIfAbsent (
151158 new ServiceAndTaskType (Metadata .ALL , model .getTaskType ()).toString (),
152- key -> new ModelStats (Metadata .ALL , model .getTaskType ())
159+ key -> createEmptyStats (Metadata .ALL , model .getTaskType ())
153160 ).add ();
154161 }
155162
@@ -162,6 +169,19 @@ private static void addStatsByServiceAndTask(
162169 addTopLevelStatsByTask (inferenceFieldsByIndexServiceAndTask , endpointStats );
163170 }
164171
172+ private static ModelStats createEmptyStats (ModelConfigurations model ) {
173+ return createEmptyStats (model .getService (), model .getTaskType ());
174+ }
175+
176+ private static ModelStats createEmptyStats (String service , TaskType taskType ) {
177+ return new ModelStats (
178+ service ,
179+ taskType ,
180+ 0 ,
181+ TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT .contains (taskType ) ? new SemanticTextStats () : null
182+ );
183+ }
184+
165185 private static void addTopLevelStatsByTask (
166186 Map <ServiceAndTaskType , Map <String , List <InferenceFieldMetadata >>> inferenceFieldsByIndexServiceAndTask ,
167187 Map <String , ModelStats > endpointStats
@@ -172,9 +192,9 @@ private static void addTopLevelStatsByTask(
172192 }
173193 ModelStats allStatsForTaskType = endpointStats .computeIfAbsent (
174194 new ServiceAndTaskType (Metadata .ALL , taskType ).toString (),
175- key -> new ModelStats (Metadata .ALL , taskType )
195+ key -> createEmptyStats (Metadata .ALL , taskType )
176196 );
177- if (taskType . isCompatibleWithSemanticText ( )) {
197+ if (TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT . contains ( taskType )) {
178198 Map <String , List <InferenceFieldMetadata >> inferenceFieldsByIndex = inferenceFieldsByIndexServiceAndTask .entrySet ()
179199 .stream ()
180200 .filter (e -> e .getKey ().taskType == taskType )
@@ -226,7 +246,12 @@ private void addStatsForDefaultModelsCompatibleWithSemanticText(
226246 // Now that we have all inference fields for this service and task type, we want to keep only the ones that
227247 // reference the current default model.
228248 fieldsByIndex = filterFields (fieldsByIndex , f -> statKey .modelId .equals (endpointIdToModelId .get (f .getInferenceId ())));
229- ModelStats stats = new ModelStats (statKey .toString (), statKey .taskType , defaultModelStatsKeyToEndpointCount .getValue ());
249+ ModelStats stats = new ModelStats (
250+ statKey .toString (),
251+ statKey .taskType ,
252+ defaultModelStatsKeyToEndpointCount .getValue (),
253+ new SemanticTextStats ()
254+ );
230255 addSemanticTextStats (fieldsByIndex , stats );
231256 endpointStats .put (statKey .toString (), stats );
232257 }
@@ -239,7 +264,7 @@ private Map<DefaultModelStatsKey, Long> createStatsKeysWithEndpointCountsForDefa
239264 // Note that endpoints could have a null model id, in which case we don't consider them default as this
240265 // may only happen for external services.
241266 Set <String > modelIds = endpoints .stream ()
242- .filter (endpoint -> endpoint .getTaskType (). isCompatibleWithSemanticText ( ))
267+ .filter (endpoint -> TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT . contains ( endpoint .getTaskType ()))
243268 .filter (endpoint -> modelRegistry .containsDefaultConfigId (endpoint .getInferenceEntityId ()))
244269 .filter (endpoint -> endpoint .getServiceSettings ().modelId () != null )
245270 .map (endpoint -> stripLinuxSuffix (endpoint .getServiceSettings ().modelId ()))
0 commit comments