Commit

fix memory CB bugs
Signed-off-by: Xun Zhang <xunzh@amazon.com>
Zhangxunmt committed May 22, 2024
1 parent 99e75aa commit ecc9928
Showing 8 changed files with 65 additions and 10 deletions.
@@ -25,6 +25,7 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskResponse;
@@ -177,6 +178,8 @@ public void onResponse(MLModel mlModel) {
);
} else if (e instanceof MLResourceNotFoundException) {
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
} else if (e instanceof MLLimitExceededException) {
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.SERVICE_UNAVAILABLE));
} else {
wrappedListener
.onFailure(
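The hunk above adds a branch that surfaces MLLimitExceededException, thrown when the memory circuit breaker is open, as HTTP 503 (SERVICE_UNAVAILABLE) instead of falling through to the generic failure path. A minimal sketch of the resulting mapping, assuming only the OpenSearchStatusException(String, RestStatus) constructor already used in this file; toStatusException is a hypothetical helper name, not code from the plugin:

    // Hypothetical helper: translate ML exceptions into REST statuses for the prediction listener.
    private static OpenSearchStatusException toStatusException(Exception e) {
        if (e instanceof MLResourceNotFoundException) {
            return new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND);            // 404
        }
        if (e instanceof MLLimitExceededException) {
            return new OpenSearchStatusException(e.getMessage(), RestStatus.SERVICE_UNAVAILABLE);  // 503: back off and retry
        }
        return new OpenSearchStatusException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR);    // 500 fallback
    }

A 503 lets clients distinguish a temporarily overloaded node from a genuine server error and retry once heap pressure subsides.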
@@ -50,6 +50,6 @@ public Short getThreshold() {

@Override
public boolean isOpen() {
return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold();
return getThreshold() < 100 && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold();
}
}
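The added getThreshold() < 100 guard gives the threshold setting an explicit off switch: once plugins.ml_commons.jvm_heap_memory_threshold is set to 100, isOpen() returns false without even reading JVM heap stats, so the memory circuit breaker is effectively disabled. A minimal sketch of the semantics, assuming a breaker with a configurable percentage threshold; the class and field names below are illustrative, not the plugin's exact code:

    // Illustrative only: the open/closed decision after the fix.
    final class HeapBreakerSketch {
        private final short thresholdPercent;                         // plugins.ml_commons.jvm_heap_memory_threshold
        private final java.util.function.IntSupplier heapUsedPercent; // e.g. reported JVM heap-used percentage

        HeapBreakerSketch(short thresholdPercent, java.util.function.IntSupplier heapUsedPercent) {
            this.thresholdPercent = thresholdPercent;
            this.heapUsedPercent = heapUsedPercent;
        }

        boolean isOpen() {
            // A threshold of 100 (or more) means "disabled": the breaker never opens.
            return thresholdPercent < 100 && heapUsedPercent.getAsInt() > thresholdPercent;
        }
    }

The new testIsOpen_DisableMemoryCB test further down exercises exactly this case: with the threshold raised to 100, the breaker stays closed even when reported heap usage is 100 percent.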
@@ -143,7 +143,13 @@ public void dispatchTask(
if (clusterService.localNode().getId().equals(node.getId())) {
log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
request.setDispatchTask(false);
executeTask(request, listener);
run(
// This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here
functionName,
request,
transportService,
listener
);
} else {
log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
request.setDispatchTask(false);
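Previously the local branch called executeTask(request, listener) directly; it now goes through run(...), the same entry point used when the task is dispatched to another node, so any pre-execution checks performed there, including the memory circuit breaker, also apply to predictions that happen to run on the local node. A hedged sketch of the effect; only the run/executeTask names and the run(...) argument list come from this hunk, and the guard below is a stand-in for whatever checks the shared task runner actually performs:

    // Hedged sketch: what routing the local path through run(...) buys.
    public void run(FunctionName functionName, MLPredictionTaskRequest request,
                    TransportService transportService, ActionListener<MLTaskResponse> listener) {
        if (memoryCircuitBreakerIsOpen()) {                // hypothetical pre-execution check
            listener.onFailure(new MLLimitExceededException("Memory Circuit Breaker is open, please check your resources!"));
            return;
        }
        executeTask(request, listener);                    // the call the local branch used to make directly
    }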
@@ -43,6 +43,7 @@
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
@@ -235,6 +236,28 @@ public void testPrediction_MLResourceNotFoundException() {
assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage());
}

public void testPrediction_MLLimitExceededException() {
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
listener.onFailure(new MLLimitExceededException("Memory Circuit Breaker is open, please check your resources!"));
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());

doAnswer(invocation -> {
((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null);
return null;
}).when(mlPredictTaskRunner).run(any(), any(), any(), any());

transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener);

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage());
}

public void testValidateInputSchemaSuccess() {
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet
.builder()
@@ -84,4 +84,22 @@ public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() {
settingsService.applySettings(newSettingsBuilder.build());
Assert.assertFalse(breaker.isOpen());
}

@Test
public void testIsOpen_DisableMemoryCB() {
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
when(clusterService.getClusterSettings()).thenReturn(settingsService);

CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);

when(mem.getHeapUsedPercent()).thenReturn((short) 90);
Assert.assertTrue(breaker.isOpen());

when(mem.getHeapUsedPercent()).thenReturn((short) 100);
Settings.Builder newSettingsBuilder = Settings.builder();
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100);
settingsService.applySettings(newSettingsBuilder.build());
Assert.assertFalse(breaker.isOpen());
}
}
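testIsOpen_DisableMemoryCB drives the new disable behavior through ClusterSettings#applySettings; the same switch is available on a live cluster by updating the dynamic setting. A hedged Java example using the standard cluster-settings update API, where client and log are assumed to be an initialized Client and logger:

    // Hedged example: disable the memory circuit breaker by raising the threshold to 100.
    ClusterUpdateSettingsRequest disableCb = new ClusterUpdateSettingsRequest()
        .persistentSettings(Settings.builder().put("plugins.ml_commons.jvm_heap_memory_threshold", 100));
    client.admin().cluster().updateSettings(disableCb, ActionListener.wrap(
        response -> log.info("memory circuit breaker disabled (threshold=100)"),
        e -> log.error("failed to update jvm_heap_memory_threshold", e)
    ));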
@@ -24,6 +24,7 @@
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.ClusterManagerMetrics;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
@@ -64,13 +65,16 @@ public class MLModelCacheHelperTests extends OpenSearchTestCase {
@Mock
private TokenBucket rateLimiter;

@Mock
ClusterManagerMetrics clusterManagerMetrics;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
maxMonitoringRequests = 10;
settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), maxMonitoringRequests).build();
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_MONITORING_REQUEST_COUNT);
clusterService = spy(new ClusterService(settings, clusterSettings, null));
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterManagerMetrics));

when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
cacheHelper = new MLModelCacheHelper(clusterService, settings);
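This and the two test classes after it now construct ClusterService with a fourth argument: a ClusterManagerMetrics instance, which newer OpenSearch core expects in the ClusterService constructor used by these tests. The shared pattern is simply to mock the metrics object and pass it through; the fragment below is the common setup shape, where everything besides the constructor call and the ClusterManagerMetrics type is the Mockito boilerplate these tests already use:

    @Mock
    ClusterManagerMetrics clusterManagerMetrics;

    // inside each test's setup():
    clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterManagerMetrics));
    when(clusterService.getClusterSettings()).thenReturn(clusterSettings);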
@@ -76,6 +76,7 @@
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterManagerMetrics;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
@@ -177,7 +178,7 @@ public class MLModelManagerTests extends OpenSearchTestCase {
private ScriptService scriptService;

@Mock
private MLTask pretrainedMLTask;
ClusterManagerMetrics clusterManagerMetrics;

@Before
public void setup() throws URISyntaxException {
@@ -196,7 +197,7 @@ public void setup() throws URISyntaxException {
ML_COMMONS_MONITORING_REQUEST_COUNT,
ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE
);
clusterService = spy(new ClusterService(settings, clusterSettings, null));
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterManagerMetrics));
xContentRegistry = NamedXContentRegistry.EMPTY;

modelName = "model_name1";
@@ -28,6 +28,7 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterManagerMetrics;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
@@ -48,7 +49,6 @@
import org.opensearch.ml.stats.suppliers.CounterSupplier;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {

@@ -70,14 +70,14 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
@Mock
MLCircuitBreakerService mlCircuitBreakerService;

@Mock
TransportService transportService;

@Mock
ActionListener<MLExecuteTaskResponse> listener;
@Mock
DiscoveryNodeHelper nodeHelper;

@Mock
ClusterManagerMetrics clusterManagerMetrics;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@@ -115,7 +115,7 @@ public void setup() {
ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE,
ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL
);
clusterService = spy(new ClusterService(settings, clusterSettings, null));
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterManagerMetrics));
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

Map<Enum, MLStat<?>> stats = new ConcurrentHashMap<>();
