Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/121559.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 121559
summary: Skip Usage stats update when ML is disabled
area: Machine Learning
type: bug
issues:
- 121532
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
Expand All @@ -36,6 +37,7 @@
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
Expand Down Expand Up @@ -111,8 +113,11 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);

// Settings taken from the factory context at construction time; read later to check
// whether machine learning is enabled before refreshing deployment stats.
private final Settings settings;

/**
 * Creates the service from the given factory context.
 *
 * @param context factory context handed to the base service; its settings are also
 *                retained so this class can consult them later
 */
public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
super(context);
// Keep the settings so the ML-enabled flag can be checked when updating models.
this.settings = context.settings();
}

// for testing
Expand All @@ -121,6 +126,7 @@ public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFa
Consumer<ActionListener<PreferredModelVariant>> platformArch
) {
super(context, platformArch);
this.settings = context.settings();
}

@Override
Expand Down Expand Up @@ -825,6 +831,12 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
return;
}

// if ML is disabled, do not update Deployment Stats (there won't be changes)
if (XPackSettings.MACHINE_LEARNING_ENABLED.get(settings) == false) {
listener.onResponse(models);
return;
}

var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
for (var model : models) {
assert model instanceof ElasticsearchInternalModel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -47,12 +48,14 @@
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests;
Expand All @@ -68,11 +71,13 @@
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
Expand All @@ -82,12 +87,14 @@
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response.RESULTS_FIELD;
import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;
Expand All @@ -102,6 +109,7 @@
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class ElasticsearchInternalServiceTests extends ESTestCase {
Expand Down Expand Up @@ -1698,6 +1706,67 @@ public void testGetConfiguration() throws Exception {
}
}

/**
 * Verifies that, when {@code xpack.ml.enabled} is {@code false},
 * {@code updateModelsWithDynamicFields} completes the listener with the very same list
 * instance it was given.
 *
 * @throws IOException          if closing the service fails
 * @throws InterruptedException if the test thread is interrupted while waiting for the listener
 */
public void testUpdateWithoutMlEnabled() throws IOException, InterruptedException {
    var cs = mock(ClusterService.class);
    var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));
    when(cs.getClusterSettings()).thenReturn(cSettings);
    var context = new InferenceServiceExtension.InferenceServiceFactoryContext(
        mock(),
        threadPool,
        cs,
        Settings.builder().put("xpack.ml.enabled", false).build()
    );
    try (var service = new ElasticsearchInternalService(context)) {
        var models = List.of(mock(Model.class));
        var response = new AtomicReference<List<Model>>();
        var latch = new CountDownLatch(1);
        service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> {
            // Only record the result here; asserting inside the callback could be lost
            // if the listener were ever completed on a thread other than the test thread.
            response.set(r);
            latch.countDown();
        }));
        assertTrue(latch.await(30, TimeUnit.SECONDS));
        // Assert on the test thread so a mismatch always fails the test.
        assertThat(response.get(), Matchers.sameInstance(models));
    }
}

/**
 * Verifies that, with {@code xpack.ml.enabled} set to {@code true}, the number of allocations
 * reported by the stubbed deployment stats response is pushed onto the model whose
 * deployment id matches.
 *
 * @throws IOException          if closing the service fails
 * @throws InterruptedException if the test thread is interrupted while waiting for the listener
 */
public void testUpdateWithMlEnabled() throws IOException, InterruptedException {
    var deploymentId = "deploymentId";

    // Stats entry reporting three allocations for the deployment.
    AssignmentStats assignmentStats = mock();
    when(assignmentStats.getDeploymentId()).thenReturn(deploymentId);
    when(assignmentStats.getNumberOfAllocations()).thenReturn(3);

    // Model bound to the same deployment id as the stats entry.
    var internalModel = mock(ElasticsearchInternalModel.class);
    when(internalModel.mlNodeDeploymentId()).thenReturn(deploymentId);

    // Client that answers every deployment-stats request with the single stats entry above.
    var mockClient = mock(Client.class);
    doAnswer(invocation -> {
        QueryPage<AssignmentStats> statsPage = new QueryPage<>(List.of(assignmentStats), 1, RESULTS_FIELD);

        GetDeploymentStatsAction.Response statsResponse = mock();
        when(statsResponse.getStats()).thenReturn(statsPage);

        ActionListener<GetDeploymentStatsAction.Response> responseListener = invocation.getArgument(2);
        responseListener.onResponse(statsResponse);
        return null;
    }).when(mockClient).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
    when(mockClient.threadPool()).thenReturn(threadPool);

    var clusterService = mock(ClusterService.class);
    var clusterSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));
    when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

    var mlEnabledSettings = Settings.builder().put("xpack.ml.enabled", true).build();
    var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(
        mockClient,
        threadPool,
        clusterService,
        mlEnabledSettings
    );

    try (var service = new ElasticsearchInternalService(factoryContext)) {
        List<Model> models = List.of(internalModel);
        var completed = new CountDownLatch(1);
        service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(ignored -> completed.countDown()));
        assertTrue(completed.await(30, TimeUnit.SECONDS));
        verify(internalModel).updateNumAllocations(3);
    }
}

private ElasticsearchInternalService createService(Client client) {
var cs = mock(ClusterService.class);
var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));
Expand Down