diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java index bb6f05100a832..e020749b3e71e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java @@ -10,9 +10,12 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; import java.io.IOException; +import static org.hamcrest.Matchers.equalTo; + public class ModelStatsTests extends AbstractWireSerializingTestCase { @Override @@ -27,12 +30,32 @@ protected ModelStats createTestInstance() { @Override protected ModelStats mutateInstance(ModelStats modelStats) throws IOException { - ModelStats newModelStats = new ModelStats(modelStats); - newModelStats.add(); - return newModelStats; + String service = modelStats.service(); + TaskType taskType = modelStats.taskType(); + long count = modelStats.count(); + return switch (randomInt(2)) { + case 0 -> new ModelStats(randomValueOtherThan(service, ESTestCase::randomIdentifier), taskType, count); + case 1 -> new ModelStats(service, randomValueOtherThan(taskType, () -> randomFrom(TaskType.values())), count); + case 2 -> new ModelStats(service, taskType, randomValueOtherThan(count, ESTestCase::randomLong)); + default -> throw new IllegalArgumentException(); + }; + } + + public void testAdd() { + ModelStats stats = new ModelStats("test_service", randomFrom(TaskType.values())); + assertThat(stats.count(), equalTo(0L)); + + stats.add(); + assertThat(stats.count(), equalTo(1L)); + + int iterations = randomIntBetween(1, 10); + for (int i = 0; i < iterations; i++) { + stats.add(); + } + assertThat(stats.count(), equalTo(1L + iterations)); } public static ModelStats createRandomInstance() { - return new ModelStats(randomIdentifier(), TaskType.values()[randomInt(TaskType.values().length - 1)], randomInt(10)); + return new ModelStats(randomIdentifier(), randomFrom(TaskType.values()), randomLong()); } }