Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,6 @@ public Collection<?> createComponents(PluginServices services) {

// Add binding for interface -> implementation
components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator));
components.add(calculator);

return components;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,55 +10,55 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.DEFAULT_MAX_NODES_PER_GROUPING;
import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;

@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 0)
public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase {

private static final Integer RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS = 15;

/**
 * Runs the superclass set-up before each test. This override adds no
 * test-specific initialization of its own; it only delegates.
 *
 * @throws Exception if the superclass set-up fails
 */
@Override
public void setUp() throws Exception {
    super.setUp();
}

public void testInitialClusterGrouping_Correct() {
public void testInitialClusterGrouping_Correct() throws Exception {
// Start with 2-5 nodes
var numNodes = randomIntBetween(2, 5);
var nodeNames = internalCluster().startNodes(numNodes);
ensureStableCluster(numNodes);

RateLimitAssignment firstAssignment = null;
var firstCalculator = getCalculatorInstance(internalCluster(), nodeNames.get(0));
waitForRateLimitingAssignments(firstCalculator);

for (String nodeName : nodeNames) {
var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeName);

// Check first node's assignments
if (firstAssignment == null) {
// Get assignment for a specific service (e.g., EIS)
firstAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);

assertNotNull(firstAssignment);
// Verify there are assignments for this service
assertFalse(firstAssignment.responsibleNodes().isEmpty());
} else {
// Verify other nodes see the same assignment
var currentAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
assertEquals(firstAssignment, currentAssignment);
}
RateLimitAssignment firstAssignment = firstCalculator.getRateLimitAssignment(
ElasticInferenceService.NAME,
TaskType.SPARSE_EMBEDDING
);

// Verify that all other nodes land on the same assignment
for (String nodeName : nodeNames.subList(1, nodeNames.size())) {
var calculator = getCalculatorInstance(internalCluster(), nodeName);
waitForRateLimitingAssignments(calculator);
var currentAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
assertEquals(firstAssignment, currentAssignment);
}
}

public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws IOException {
public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws Exception {
// Start with 3-5 nodes
var numNodes = randomIntBetween(3, 5);
var nodeNames = internalCluster().startNodes(numNodes);
Expand All @@ -77,7 +77,8 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
ensureStableCluster(currentNumberOfNodes);
}

var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeLeftInCluster);
var calculator = getCalculatorInstance(internalCluster(), nodeLeftInCluster);
waitForRateLimitingAssignments(calculator);

Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();

Expand All @@ -93,13 +94,14 @@ public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws
}
}

public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
public void testGrouping_RespectsMaxNodesPerGroupingLimit() throws Exception {
// Start with more nodes than a single grouping allows
var numNodes = DEFAULT_MAX_NODES_PER_GROUPING + randomIntBetween(1, 3);
var nodeNames = internalCluster().startNodes(numNodes);
ensureStableCluster(numNodes);

var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
var calculator = getCalculatorInstance(internalCluster(), nodeNames.get(0));
waitForRateLimitingAssignments(calculator);

Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();

Expand All @@ -111,13 +113,14 @@ public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
}
}

public void testInitialRateLimitsCalculation_Correct() throws IOException {
public void testInitialRateLimitsCalculation_Correct() throws Exception {
// Start with max nodes per grouping (=3)
int numNodes = DEFAULT_MAX_NODES_PER_GROUPING;
var nodeNames = internalCluster().startNodes(numNodes);
ensureStableCluster(numNodes);

var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
var calculator = getCalculatorInstance(internalCluster(), nodeNames.get(0));
waitForRateLimitingAssignments(calculator);

Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();

Expand All @@ -129,7 +132,7 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {

if ((service instanceof SenderService senderService)) {
var sender = senderService.getSender();
if (sender instanceof HttpRequestSender httpSender) {
if (sender instanceof HttpRequestSender) {
var assignment = calculator.getRateLimitAssignment(service.name(), TaskType.SPARSE_EMBEDDING);

assertNotNull(assignment);
Expand All @@ -141,13 +144,14 @@ public void testInitialRateLimitsCalculation_Correct() throws IOException {
}
}

public void testRateLimits_Decrease_OnNodeJoin() {
public void testRateLimits_Decrease_OnNodeJoin() throws Exception {
// Start with 2 nodes
var initialNodes = 2;
var nodeNames = internalCluster().startNodes(initialNodes);
ensureStableCluster(initialNodes);

var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
var calculator = getCalculatorInstance(internalCluster(), nodeNames.get(0));
waitForRateLimitingAssignments(calculator);

for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName);
Expand All @@ -159,6 +163,7 @@ public void testRateLimits_Decrease_OnNodeJoin() {
// Add a new node
internalCluster().startNode();
ensureStableCluster(initialNodes + 1);
waitForRateLimitingAssignments(calculator);

// Get updated assignments
var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType());
Expand All @@ -169,13 +174,14 @@ public void testRateLimits_Decrease_OnNodeJoin() {
}
}

public void testRateLimits_Increase_OnNodeLeave() throws IOException {
public void testRateLimits_Increase_OnNodeLeave() throws Exception {
// Start with max nodes per grouping (=3)
int numNodes = DEFAULT_MAX_NODES_PER_GROUPING;
var nodeNames = internalCluster().startNodes(numNodes);
ensureStableCluster(numNodes);

var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
var calculator = getCalculatorInstance(internalCluster(), nodeNames.get(0));
waitForRateLimitingAssignments(calculator);

for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName);
Expand All @@ -188,6 +194,7 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
var nodeToRemove = nodeNames.get(numNodes - 1);
internalCluster().stopNode(nodeToRemove);
ensureStableCluster(numNodes - 1);
waitForRateLimitingAssignments(calculator);

// Get updated assignments
var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType());
Expand All @@ -202,4 +209,33 @@ public void testRateLimits_Increase_OnNodeLeave() throws IOException {
protected Collection<Class<? extends Plugin>> nodePlugins() {
    // Only the local-state inference plugin is handed to the test nodes.
    Collection<Class<? extends Plugin>> pluginClasses = Arrays.asList(LocalStateInferencePlugin.class);
    return pluginClasses;
}

/**
 * Asks the test cluster for the {@link InferenceServiceRateLimitCalculator} instance of the
 * given node and returns it typed as {@link InferenceServiceNodeLocalRateLimitCalculator}.
 * The test fails with a descriptive message if the instance is of any other type.
 *
 * @param internalTestCluster cluster to look the instance up in
 * @param nodeName            name of the node whose calculator is requested
 * @return the calculator, cast to the node-local implementation
 */
private InferenceServiceNodeLocalRateLimitCalculator getCalculatorInstance(InternalTestCluster internalTestCluster, String nodeName) {
    var resolved = internalTestCluster.getInstance(InferenceServiceRateLimitCalculator.class, nodeName);

    // Collect the class names up front so the failure message reads as a single expression.
    String testClassName = InferenceServiceNodeLocalRateLimitCalculatorTests.class.getName();
    String expectedImplName = InferenceServiceNodeLocalRateLimitCalculator.class.getName();
    String interfaceName = InferenceServiceRateLimitCalculator.class.getName();
    String actualImplName = resolved.getClass().getName();

    String mismatchReason = "[" + testClassName + "] should use [" + expectedImplName + "] as implementation for ["
        + interfaceName + "]. Provided implementation was [" + actualImplName + "].";

    assertThat(mismatchReason, resolved, instanceOf(InferenceServiceNodeLocalRateLimitCalculator.class));
    return (InferenceServiceNodeLocalRateLimitCalculator) resolved;
}

/**
 * Retries via {@code assertBusy}, for at most
 * {@code RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS} seconds, until the calculator
 * returns a non-null assignment with at least one responsible node for the
 * {@link ElasticInferenceService} sparse-embedding task.
 *
 * @param calculator calculator whose assignment is checked
 * @throws Exception propagated from {@code assertBusy}
 */
private void waitForRateLimitingAssignments(InferenceServiceNodeLocalRateLimitCalculator calculator) throws Exception {
    assertBusy(
        () -> assertHasResponsibleNodes(calculator),
        RATE_LIMIT_ASSIGNMENT_MAX_WAIT_TIME_IN_SECONDS,
        TimeUnit.SECONDS
    );
}

/** Single check: the sparse-embedding assignment exists and names at least one responsible node. */
private static void assertHasResponsibleNodes(InferenceServiceNodeLocalRateLimitCalculator calculator) {
    var sparseEmbeddingAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
    assertNotNull(sparseEmbeddingAssignment);
    assertFalse(sparseEmbeddingAssignment.responsibleNodes().isEmpty());
}
}