Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,16 @@
@LuceneTestCase.SuppressCodecs("*") // use our custom codec
public class GPUIndexIT extends ESIntegTestCase {

public static class TestGPUPlugin extends GPUPlugin {
@Override
protected boolean isGpuIndexingFeatureAllowed() {
return true;
}
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return List.of(GPUPlugin.class);
return List.of(TestGPUPlugin.class);
}

@BeforeClass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.junit.After;

import java.util.Collection;
import java.util.List;

import static org.elasticsearch.xpack.gpu.TestVectorsFormatUtils.randomGPUSupportedSimilarity;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.startsWith;

public class GPUPluginInitializationWithGPUIT extends ESIntegTestCase {

static {
TestCuVSServiceProvider.mockedGPUInfoProvider = SUPPORTEp -> new TestCuVSServiceProvider.TestGPUInfoProvider(
TestCuVSServiceProvider.mockedGPUInfoProvider = p -> new TestCuVSServiceProvider.TestGPUInfoProvider(
List.of(
new GPUInfo(
0,
Expand All @@ -44,15 +46,34 @@ public class GPUPluginInitializationWithGPUIT extends ESIntegTestCase {
);
}

private static boolean isGpuIndexingFeatureAllowed = true;

public static class TestGPUPlugin extends GPUPlugin {

public TestGPUPlugin() {
super();
}

@Override
protected boolean isGpuIndexingFeatureAllowed() {
return GPUPluginInitializationWithGPUIT.isGpuIndexingFeatureAllowed;
}
}

@After
public void reset() {
isGpuIndexingFeatureAllowed = true;
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return List.of(GPUPlugin.class);
return List.of(TestGPUPlugin.class);
}

public void testFFOff() {
assumeFalse("GPU_FORMAT feature flag disabled", GPUPlugin.GPU_FORMAT.isEnabled());

GPUPlugin gpuPlugin = internalCluster().getInstance(GPUPlugin.class);
GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();

var format = vectorsFormatProvider.getKnnVectorsFormat(null, null, null);
Expand All @@ -74,7 +95,7 @@ public void testFFOffIndexSettingNotSupported() {
public void testFFOffGPUFormatNull() {
assumeFalse("GPU_FORMAT feature flag disabled", GPUPlugin.GPU_FORMAT.isEnabled());

GPUPlugin gpuPlugin = internalCluster().getInstance(GPUPlugin.class);
GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();

createIndex("index1", Settings.EMPTY);
Expand All @@ -89,10 +110,10 @@ public void testFFOffGPUFormatNull() {
assertNull(format);
}

public void testIndexSettingOnIndexTypeSupportedGPUSupported() {
public void testIndexSettingOnIndexAllSupported() {
assumeTrue("GPU_FORMAT feature flag enabled", GPUPlugin.GPU_FORMAT.isEnabled());

GPUPlugin gpuPlugin = internalCluster().getInstance(GPUPlugin.class);
GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();

createIndex("index1", Settings.builder().put(GPUPlugin.VECTORS_INDEXING_USE_GPU_SETTING.getKey(), GPUPlugin.GpuMode.TRUE).build());
Expand All @@ -110,7 +131,7 @@ public void testIndexSettingOnIndexTypeSupportedGPUSupported() {
public void testIndexSettingOnIndexTypeNotSupportedThrows() {
assumeTrue("GPU_FORMAT feature flag enabled", GPUPlugin.GPU_FORMAT.isEnabled());

GPUPlugin gpuPlugin = internalCluster().getInstance(GPUPlugin.class);
GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();

createIndex("index1", Settings.builder().put(GPUPlugin.VECTORS_INDEXING_USE_GPU_SETTING.getKey(), GPUPlugin.GpuMode.TRUE).build());
Expand All @@ -124,10 +145,31 @@ public void testIndexSettingOnIndexTypeNotSupportedThrows() {
assertThat(ex.getMessage(), startsWith("[index.vectors.indexing.use_gpu] doesn't support [index_options.type] of"));
}

public void testIndexSettingAutoIndexTypeSupportedGPUSupported() {
public void testIndexSettingOnIndexLicenseNotSupportedThrows() {
assumeTrue("GPU_FORMAT feature flag enabled", GPUPlugin.GPU_FORMAT.isEnabled());
isGpuIndexingFeatureAllowed = false;

GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();

createIndex("index1", Settings.builder().put(GPUPlugin.VECTORS_INDEXING_USE_GPU_SETTING.getKey(), GPUPlugin.GpuMode.TRUE).build());
IndexSettings settings = getIndexSettings();
final var indexOptions = DenseVectorFieldTypeTests.randomGpuSupportedIndexOptions();

var ex = expectThrows(
IllegalArgumentException.class,
() -> vectorsFormatProvider.getKnnVectorsFormat(settings, indexOptions, randomGPUSupportedSimilarity(indexOptions.getType()))
);
assertThat(
ex.getMessage(),
equalTo("[index.vectors.indexing.use_gpu] was set to [true], but GPU indexing is a [ENTERPRISE] level feature")
);
}

public void testIndexSettingAutoAllSupported() {
assumeTrue("GPU_FORMAT feature flag enabled", GPUPlugin.GPU_FORMAT.isEnabled());

GPUPlugin gpuPlugin = internalCluster().getInstance(GPUPlugin.class);
GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();

createIndex("index1", Settings.builder().put(GPUPlugin.VECTORS_INDEXING_USE_GPU_SETTING.getKey(), GPUPlugin.GpuMode.AUTO).build());
Expand All @@ -142,10 +184,29 @@ public void testIndexSettingAutoIndexTypeSupportedGPUSupported() {
assertNotNull(format);
}

public void testIndexSettingAutoLicenseNotSupported() {
assumeTrue("GPU_FORMAT feature flag enabled", GPUPlugin.GPU_FORMAT.isEnabled());
isGpuIndexingFeatureAllowed = false;

GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();

createIndex("index1", Settings.builder().put(GPUPlugin.VECTORS_INDEXING_USE_GPU_SETTING.getKey(), GPUPlugin.GpuMode.AUTO).build());
IndexSettings settings = getIndexSettings();
final var indexOptions = DenseVectorFieldTypeTests.randomGpuSupportedIndexOptions();

var format = vectorsFormatProvider.getKnnVectorsFormat(
settings,
indexOptions,
randomGPUSupportedSimilarity(indexOptions.getType())
);
assertNull(format);
}

public void testIndexSettingAutoIndexTypeNotSupported() {
assumeTrue("GPU_FORMAT feature flag enabled", GPUPlugin.GPU_FORMAT.isEnabled());

GPUPlugin gpuPlugin = internalCluster().getInstance(GPUPlugin.class);
GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();

createIndex("index1", Settings.builder().put(GPUPlugin.VECTORS_INDEXING_USE_GPU_SETTING.getKey(), GPUPlugin.GpuMode.AUTO).build());
Expand All @@ -163,7 +224,7 @@ public void testIndexSettingAutoIndexTypeNotSupported() {
public void testIndexSettingOff() {
assumeTrue("GPU_FORMAT feature flag enabled", GPUPlugin.GPU_FORMAT.isEnabled());

GPUPlugin gpuPlugin = internalCluster().getInstance(GPUPlugin.class);
GPUPlugin gpuPlugin = internalCluster().getInstance(TestGPUPlugin.class);
VectorsFormatProvider vectorsFormatProvider = gpuPlugin.getVectorsFormatProvider();

createIndex("index1", Settings.builder().put(GPUPlugin.VECTORS_INDEXING_USE_GPU_SETTING.getKey(), GPUPlugin.GpuMode.FALSE).build());
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugin/gpu/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
requires org.elasticsearch.server;
requires org.elasticsearch.base;
requires org.elasticsearch.gpu;
requires org.elasticsearch.xcore;

provides org.elasticsearch.features.FeatureSpecification with org.elasticsearch.xpack.gpu.GPUFeatures;
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@
import org.elasticsearch.gpu.codec.ES92GpuHnswVectorsFormat;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.VectorsFormatProvider;
import org.elasticsearch.license.License;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.internal.InternalVectorFormatProviderPlugin;
import org.elasticsearch.xpack.core.XPackPlugin;

import java.util.List;

public class GPUPlugin extends Plugin implements InternalVectorFormatProviderPlugin {

public static final FeatureFlag GPU_FORMAT = new FeatureFlag("gpu_vectors_indexing");

private static final License.OperationMode MINIMUM_ALLOWED_LICENSE = License.OperationMode.ENTERPRISE;

/**
* An enum for the tri-state value of the `index.vectors.indexing.use_gpu` setting.
*/
Expand Down Expand Up @@ -59,6 +63,12 @@ public List<Setting<?>> getSettings() {
}
}

// Allow tests to override the license state
protected boolean isGpuIndexingFeatureAllowed() {
var licenseState = XPackPlugin.getSharedLicenseState();
return licenseState != null && licenseState.isAllowedByLicense(MINIMUM_ALLOWED_LICENSE);
}

@Override
public VectorsFormatProvider getVectorsFormatProvider() {
return (indexSettings, indexOptions, similarity) -> {
Expand All @@ -75,9 +85,19 @@ public VectorsFormatProvider getVectorsFormatProvider() {
"[index.vectors.indexing.use_gpu] was set to [true], but GPU resources are not accessible on the node."
);
}
if (isGpuIndexingFeatureAllowed() == false) {
throw new IllegalArgumentException(
"[index.vectors.indexing.use_gpu] was set to [true], but GPU indexing is a ["
+ MINIMUM_ALLOWED_LICENSE
+ "] level feature"
);
}
return getVectorsFormat(indexOptions, similarity);
}
if (gpuMode == GpuMode.AUTO && vectorIndexTypeSupported(indexOptions.getType()) && GPUSupport.isSupported()) {
if (gpuMode == GpuMode.AUTO
&& vectorIndexTypeSupported(indexOptions.getType())
&& GPUSupport.isSupported()
&& isGpuIndexingFeatureAllowed()) {
return getVectorsFormat(indexOptions, similarity);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ public static void setup() {

@Override
protected Collection<Plugin> getPlugins() {
var plugin = new GPUPlugin();
var plugin = new GPUPlugin() {
@Override
protected boolean isGpuIndexingFeatureAllowed() {
return true;
}
};
return Collections.singletonList(plugin);
}

Expand Down