diff --git a/docs/changelog/121559.yaml b/docs/changelog/121559.yaml new file mode 100644 index 0000000000000..e3870609a454f --- /dev/null +++ b/docs/changelog/121559.yaml @@ -0,0 +1,6 @@ +pr: 121559 +summary: Skip Usage stats update when ML is disabled +area: Machine Learning +type: bug +issues: + - 121532 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 9410c4e04b0df..7c1c57283fe98 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; @@ -36,6 +37,7 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; @@ -111,8 +113,11 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class); + private final Settings settings; + public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) { super(context); + this.settings = context.settings(); } // for testing @@ -121,6 +126,7 @@ public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFa Consumer> platformArch ) { super(context, platformArch); + this.settings = context.settings(); } @Override @@ -825,6 +831,12 @@ public void updateModelsWithDynamicFields(List models, ActionListener(); for (var model : models) { assert model instanceof ElasticsearchInternalModel; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 7b5dddb8b6bba..02af49d2aa467 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.LatchedActionListener; +import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; @@ -47,12 +48,14 @@ import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.MachineLearningField; +import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; @@ -68,11 +71,13 @@ import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.EnumSet; @@ -82,12 +87,14 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD; import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; @@ -102,6 +109,7 @@ import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class ElasticsearchInternalServiceTests extends ESTestCase { @@ -1698,6 +1706,67 @@ public void testGetConfiguration() throws Exception { } } + public void testUpdateWithoutMlEnabled() throws IOException, InterruptedException { + var cs = mock(ClusterService.class); + var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); + when(cs.getClusterSettings()).thenReturn(cSettings); + var context = new InferenceServiceExtension.InferenceServiceFactoryContext( + mock(), + threadPool, + cs, + Settings.builder().put("xpack.ml.enabled", false).build() + ); + try (var service = new ElasticsearchInternalService(context)) { + var models = List.of(mock(Model.class)); + var latch = new CountDownLatch(1); + service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> { + latch.countDown(); + assertThat(r, Matchers.sameInstance(models)); + })); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + } + } + + public void testUpdateWithMlEnabled() throws IOException, InterruptedException { + var deploymentId = "deploymentId"; + var model = mock(ElasticsearchInternalModel.class); + when(model.mlNodeDeploymentId()).thenReturn(deploymentId); + + AssignmentStats stats = mock(); + when(stats.getDeploymentId()).thenReturn(deploymentId); + when(stats.getNumberOfAllocations()).thenReturn(3); + + var client = mock(Client.class); + doAnswer(ans -> { + QueryPage queryPage = new QueryPage<>(List.of(stats), 1, RESULTS_FIELD); + + GetDeploymentStatsAction.Response response = mock(); + when(response.getStats()).thenReturn(queryPage); + + ActionListener listener = ans.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + when(client.threadPool()).thenReturn(threadPool); + + var cs = mock(ClusterService.class); + var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES)); + when(cs.getClusterSettings()).thenReturn(cSettings); + var context = new InferenceServiceExtension.InferenceServiceFactoryContext( + client, + threadPool, + cs, + Settings.builder().put("xpack.ml.enabled", true).build() + ); + try (var service = new ElasticsearchInternalService(context)) { + List models = List.of(model); + var latch = new CountDownLatch(1); + service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> latch.countDown())); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + verify(model).updateNumAllocations(3); + } + } + private ElasticsearchInternalService createService(Client client) { var cs = mock(ClusterService.class); var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));