-
Notifications
You must be signed in to change notification settings - Fork 24.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] add ML package loader module (#95207)
This PR introduces a new x-pack module for downloading and installing prepackaged models. The module is necessary, because we have to bypass the java security manager in order to open an http connections and/or access a file. The module limits this to only the minimum number of classes. 2 internal actions are introduced to get metadata of a model and the model itself. Core changes have been implemented in: #95175
- Loading branch information
Hendrik Muhs
committed
Apr 13, 2023
1 parent
107bb18
commit 9114965
Showing
7 changed files
with
566 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
import org.apache.tools.ant.taskdefs.condition.Os | ||
import org.elasticsearch.gradle.OS | ||
|
||
apply plugin: 'elasticsearch.internal-es-plugin' | ||
|
||
esplugin { | ||
name 'ml-package-loader' | ||
description 'Loader for prepackaged Machine Learning Models from Elastic' | ||
classname 'org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader' | ||
extendedPlugins = ['x-pack-core'] | ||
} | ||
|
||
dependencies { | ||
implementation project(path: ':libs:elasticsearch-logging') | ||
compileOnly project(":server") | ||
compileOnly project(path: xpackModule('core')) | ||
testImplementation(testArtifact(project(xpackModule('core')))) | ||
} |
55 changes: 55 additions & 0 deletions
55
.../src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
package org.elasticsearch.xpack.ml.packageloader; | ||
|
||
import org.elasticsearch.action.ActionRequest; | ||
import org.elasticsearch.action.ActionResponse; | ||
import org.elasticsearch.common.settings.Setting; | ||
import org.elasticsearch.common.settings.Settings; | ||
import org.elasticsearch.plugins.ActionPlugin; | ||
import org.elasticsearch.plugins.Plugin; | ||
import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction; | ||
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction; | ||
import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction; | ||
import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
|
||
public class MachineLearningPackageLoader extends Plugin implements ActionPlugin { | ||
|
||
private final Settings settings; | ||
|
||
public static final String DEFAULT_ML_MODELS_REPOSITORY = "https://ml-models.elastic.co"; | ||
public static final Setting<String> MODEL_REPOSITORY = Setting.simpleString( | ||
"xpack.ml.model_repository", | ||
DEFAULT_ML_MODELS_REPOSITORY, | ||
Setting.Property.NodeScope, | ||
Setting.Property.Dynamic | ||
); | ||
|
||
// re-using thread pool setup by the ml plugin | ||
public static final String UTILITY_THREAD_POOL_NAME = "ml_utility"; | ||
|
||
public MachineLearningPackageLoader(Settings settings) { | ||
this.settings = settings; | ||
} | ||
|
||
@Override | ||
public List<Setting<?>> getSettings() { | ||
return List.of(MODEL_REPOSITORY); | ||
} | ||
|
||
@Override | ||
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() { | ||
// all internal, no rest endpoint | ||
return Arrays.asList( | ||
new ActionHandler<>(GetTrainedModelPackageConfigAction.INSTANCE, TransportGetTrainedModelPackageConfigAction.class), | ||
new ActionHandler<>(LoadTrainedModelPackageAction.INSTANCE, TransportLoadTrainedModelPackage.class) | ||
); | ||
} | ||
} |
195 changes: 195 additions & 0 deletions
195
...oader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.ml.packageloader.action; | ||
|
||
import org.elasticsearch.ElasticsearchStatusException; | ||
import org.elasticsearch.ResourceNotFoundException; | ||
import org.elasticsearch.SpecialPermission; | ||
import org.elasticsearch.common.bytes.BytesArray; | ||
import org.elasticsearch.common.hash.MessageDigests; | ||
import org.elasticsearch.common.io.Streams; | ||
import org.elasticsearch.common.unit.ByteSizeUnit; | ||
import org.elasticsearch.common.unit.ByteSizeValue; | ||
import org.elasticsearch.core.SuppressForbidden; | ||
import org.elasticsearch.core.Tuple; | ||
import org.elasticsearch.rest.RestStatus; | ||
import org.elasticsearch.xcontent.XContentParser; | ||
import org.elasticsearch.xcontent.XContentParserConfiguration; | ||
import org.elasticsearch.xcontent.XContentType; | ||
|
||
import java.io.File; | ||
import java.io.IOException; | ||
import java.io.InputStream; | ||
import java.io.UncheckedIOException; | ||
import java.net.HttpURLConnection; | ||
import java.net.URI; | ||
import java.nio.file.Files; | ||
import java.security.AccessController; | ||
import java.security.MessageDigest; | ||
import java.security.PrivilegedAction; | ||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
|
||
import static java.net.HttpURLConnection.HTTP_MOVED_PERM; | ||
import static java.net.HttpURLConnection.HTTP_MOVED_TEMP; | ||
import static java.net.HttpURLConnection.HTTP_NOT_FOUND; | ||
import static java.net.HttpURLConnection.HTTP_OK; | ||
import static java.net.HttpURLConnection.HTTP_SEE_OTHER; | ||
|
||
/** | ||
* Helper class for downloading pre-trained Elastic models, available on ml-models.elastic.co or as file | ||
*/ | ||
final class ModelLoaderUtils { | ||
|
||
public static String METADATA_FILE_EXTENSION = ".metadata.json"; | ||
public static String MODEL_FILE_EXTENSION = ".pt"; | ||
|
||
private static ByteSizeValue VOCABULARY_SIZE_LIMIT = new ByteSizeValue(10, ByteSizeUnit.MB); | ||
private static final String VOCABULARY = "vocabulary"; | ||
private static final String MERGES = "merges"; | ||
|
||
static class InputStreamChunker { | ||
|
||
private final InputStream inputStream; | ||
private final MessageDigest digestSha256 = MessageDigests.sha256(); | ||
private final int chunkSize; | ||
|
||
InputStreamChunker(InputStream inputStream, int chunkSize) { | ||
this.inputStream = inputStream; | ||
this.chunkSize = chunkSize; | ||
} | ||
|
||
public BytesArray next() throws IOException { | ||
int bytesRead = 0; | ||
byte[] buf = new byte[chunkSize]; | ||
|
||
while (bytesRead < chunkSize) { | ||
int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead); | ||
// EOF?? | ||
if (read == -1) { | ||
break; | ||
} | ||
bytesRead += read; | ||
} | ||
digestSha256.update(buf, 0, bytesRead); | ||
|
||
return new BytesArray(buf, 0, bytesRead); | ||
} | ||
|
||
public String getSha256() { | ||
return MessageDigests.toHexString(digestSha256.digest()); | ||
} | ||
|
||
} | ||
|
||
static InputStream getInputStreamFromModelRepository(URI uri) throws IOException { | ||
String scheme = uri.getScheme().toLowerCase(Locale.ROOT); | ||
switch (scheme) { | ||
case "http": | ||
case "https": | ||
return getHttpOrHttpsInputStream(uri); | ||
case "file": | ||
return getFileInputStream(uri); | ||
default: | ||
throw new IllegalArgumentException("unsupported scheme"); | ||
} | ||
} | ||
|
||
public static Tuple<List<String>, List<String>> loadVocabulary(URI uri) { | ||
try { | ||
InputStream vocabInputStream = getInputStreamFromModelRepository(uri); | ||
|
||
if (uri.getPath().endsWith(".json")) { | ||
XContentParser sourceParser = XContentType.JSON.xContent() | ||
.createParser( | ||
XContentParserConfiguration.EMPTY, | ||
Streams.limitStream(vocabInputStream, VOCABULARY_SIZE_LIMIT.getBytes()) | ||
); | ||
Map<String, List<Object>> vocabAndMerges = sourceParser.map(HashMap::new, XContentParser::list); | ||
|
||
List<String> vocabulary = vocabAndMerges.containsKey(VOCABULARY) | ||
? vocabAndMerges.get(VOCABULARY).stream().map(Object::toString).collect(Collectors.toList()) | ||
: Collections.emptyList(); | ||
List<String> merges = vocabAndMerges.containsKey(MERGES) | ||
? vocabAndMerges.get(MERGES).stream().map(Object::toString).collect(Collectors.toList()) | ||
: Collections.emptyList(); | ||
|
||
return Tuple.tuple(vocabulary, merges); | ||
} | ||
|
||
throw new IllegalArgumentException("unknown format vocabulary file format"); | ||
} catch (Exception e) { | ||
throw new RuntimeException("Failed to load vocabulary file", e); | ||
} | ||
} | ||
|
||
private ModelLoaderUtils() {} | ||
|
||
@SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ") | ||
@SuppressForbidden(reason = "we need socket connection to download") | ||
private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException { | ||
|
||
SecurityManager sm = System.getSecurityManager(); | ||
if (sm != null) { | ||
sm.checkPermission(new SpecialPermission()); | ||
} | ||
|
||
PrivilegedAction<InputStream> privilegedHttpReader = () -> { | ||
try { | ||
HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection(); | ||
switch (conn.getResponseCode()) { | ||
case HTTP_OK: | ||
return conn.getInputStream(); | ||
case HTTP_MOVED_PERM: | ||
case HTTP_MOVED_TEMP: | ||
case HTTP_SEE_OTHER: | ||
throw new IllegalStateException("redirects aren't supported yet"); | ||
case HTTP_NOT_FOUND: | ||
throw new ResourceNotFoundException("{} not found", uri); | ||
default: | ||
int responseCode = conn.getResponseCode(); | ||
throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), uri); | ||
} | ||
} catch (IOException e) { | ||
throw new UncheckedIOException(e); | ||
} | ||
}; | ||
|
||
return AccessController.doPrivileged(privilegedHttpReader); | ||
} | ||
|
||
@SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ") | ||
@SuppressForbidden(reason = "we need load model data from a file") | ||
private static InputStream getFileInputStream(URI uri) { | ||
|
||
SecurityManager sm = System.getSecurityManager(); | ||
if (sm != null) { | ||
sm.checkPermission(new SpecialPermission()); | ||
} | ||
|
||
PrivilegedAction<InputStream> privilegedFileReader = () -> { | ||
File file = new File(uri); | ||
if (file.exists() == false) { | ||
throw new ResourceNotFoundException("{} not found", uri); | ||
} | ||
|
||
try { | ||
return Files.newInputStream(file.toPath()); | ||
} catch (IOException e) { | ||
throw new UncheckedIOException(e); | ||
} | ||
}; | ||
|
||
return AccessController.doPrivileged(privilegedFileReader); | ||
} | ||
|
||
} |
Oops, something went wrong.