Skip to content

Commit

Permalink
[ML] add ML package loader module (#95207)
Browse files Browse the repository at this point in the history
This PR introduces a new x-pack module for downloading and installing prepackaged models. The module is necessary because we have to bypass the Java security manager in order to open HTTP connections and/or access files. The module limits this to the minimum number of classes. Two internal actions are introduced: one to get the metadata of a model and one to get the model itself.

Core changes have been implemented in: #95175
  • Loading branch information
Hendrik Muhs committed Apr 13, 2023
1 parent 107bb18 commit 9114965
Show file tree
Hide file tree
Showing 7 changed files with 566 additions and 0 deletions.
25 changes: 25 additions & 0 deletions x-pack/plugin/ml-package-loader/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import org.apache.tools.ant.taskdefs.condition.Os
import org.elasticsearch.gradle.OS

// Build conventions for this module come from the 'elasticsearch.internal-es-plugin' Gradle plugin.
apply plugin: 'elasticsearch.internal-es-plugin'

// Plugin descriptor: name, human-readable description and the plugin entry-point class.
esplugin {
name 'ml-package-loader'
description 'Loader for prepackaged Machine Learning Models from Elastic'
classname 'org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader'
// Declares x-pack-core as the plugin this one extends.
extendedPlugins = ['x-pack-core']
}

// Module dependencies.
// NOTE(review): the compileOnly entries are presumably supplied at runtime by the server and the
// extended x-pack-core plugin - confirm against the plugin packaging rules.
dependencies {
implementation project(path: ':libs:elasticsearch-logging')
compileOnly project(":server")
compileOnly project(path: xpackModule('core'))
// Tests additionally use the test artifact published by x-pack core.
testImplementation(testArtifact(project(xpackModule('core'))))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.ml.packageloader;

import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage;

import java.util.Arrays;
import java.util.List;

/**
 * Plugin entry point of the ML package loader module.
 * <p>
 * It registers the {@link #MODEL_REPOSITORY} setting and the two internal
 * transport actions of this module; no REST endpoints are exposed.
 */
public class MachineLearningPackageLoader extends Plugin implements ActionPlugin {

    /** Repository location used when {@link #MODEL_REPOSITORY} is not configured. */
    public static final String DEFAULT_ML_MODELS_REPOSITORY = "https://ml-models.elastic.co";

    /** Node-scoped, dynamic setting holding the location of the model repository. */
    public static final Setting<String> MODEL_REPOSITORY = Setting.simpleString(
        "xpack.ml.model_repository",
        DEFAULT_ML_MODELS_REPOSITORY,
        Setting.Property.NodeScope,
        Setting.Property.Dynamic
    );

    // re-using thread pool setup by the ml plugin
    public static final String UTILITY_THREAD_POOL_NAME = "ml_utility";

    private final Settings settings;

    /**
     * @param settings the settings handed to the plugin on construction
     */
    public MachineLearningPackageLoader(Settings settings) {
        this.settings = settings;
    }

    /**
     * @return the settings registered by this plugin
     */
    @Override
    public List<Setting<?>> getSettings() {
        final List<Setting<?>> registeredSettings = List.of(MODEL_REPOSITORY);
        return registeredSettings;
    }

    /**
     * @return handlers for the package-config and package-load actions
     */
    @Override
    public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
        // all internal, no rest endpoint
        final ActionHandler<? extends ActionRequest, ? extends ActionResponse> packageConfigHandler = new ActionHandler<>(
            GetTrainedModelPackageConfigAction.INSTANCE,
            TransportGetTrainedModelPackageConfigAction.class
        );
        final ActionHandler<? extends ActionRequest, ? extends ActionResponse> packageLoadHandler = new ActionHandler<>(
            LoadTrainedModelPackageAction.INSTANCE,
            TransportLoadTrainedModelPackage.class
        );
        return Arrays.asList(packageConfigHandler, packageLoadHandler);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.packageloader.action;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.hash.MessageDigests;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.HttpURLConnection;
import java.net.URI;
import java.nio.file.Files;
import java.security.AccessController;
import java.security.MessageDigest;
import java.security.PrivilegedAction;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import static java.net.HttpURLConnection.HTTP_MOVED_PERM;
import static java.net.HttpURLConnection.HTTP_MOVED_TEMP;
import static java.net.HttpURLConnection.HTTP_NOT_FOUND;
import static java.net.HttpURLConnection.HTTP_OK;
import static java.net.HttpURLConnection.HTTP_SEE_OTHER;

/**
 * Helper class for downloading pre-trained Elastic models, available on ml-models.elastic.co or as a file.
 * <p>
 * Streams are opened inside {@link AccessController#doPrivileged(PrivilegedAction)} blocks, after
 * checking {@link SpecialPermission} when a security manager is installed.
 */
final class ModelLoaderUtils {

    public static final String METADATA_FILE_EXTENSION = ".metadata.json";
    public static final String MODEL_FILE_EXTENSION = ".pt";

    // upper bound on the number of bytes read from a vocabulary file
    private static final ByteSizeValue VOCABULARY_SIZE_LIMIT = new ByteSizeValue(10, ByteSizeUnit.MB);
    private static final String VOCABULARY = "vocabulary";
    private static final String MERGES = "merges";

    /**
     * Reads an input stream in chunks of at most {@code chunkSize} bytes while
     * updating a SHA-256 digest over every byte handed out.
     */
    static class InputStreamChunker {

        private final InputStream inputStream;
        private final MessageDigest digestSha256 = MessageDigests.sha256();
        private final int chunkSize;

        /**
         * @param inputStream the stream to read from; it is not closed by this class
         * @param chunkSize   the maximum number of bytes returned by one call to {@link #next()}
         */
        InputStreamChunker(InputStream inputStream, int chunkSize) {
            this.inputStream = inputStream;
            this.chunkSize = chunkSize;
        }

        /**
         * Reads the next chunk.
         *
         * @return up to {@code chunkSize} bytes; fewer (possibly zero) once the end of the stream is reached
         * @throws IOException if reading from the underlying stream fails
         */
        public BytesArray next() throws IOException {
            int bytesRead = 0;
            byte[] buf = new byte[chunkSize];

            // a single read() may return fewer bytes than requested, keep reading until the chunk is full
            while (bytesRead < chunkSize) {
                int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead);
                // end of stream, hand out what has been read so far
                if (read == -1) {
                    break;
                }
                bytesRead += read;
            }
            digestSha256.update(buf, 0, bytesRead);

            return new BytesArray(buf, 0, bytesRead);
        }

        /**
         * Completes the digest over all bytes returned by {@link #next()} so far.
         * {@link MessageDigest#digest()} resets the digest, so call this once, after the last chunk.
         *
         * @return the SHA-256 digest as a hex string
         */
        public String getSha256() {
            return MessageDigests.toHexString(digestSha256.digest());
        }

    }

    /**
     * Opens a stream for the given location, dispatching on the URI scheme.
     *
     * @param uri an {@code http}, {@code https} or {@code file} URI
     * @return the opened stream; the caller is responsible for closing it
     * @throws IllegalArgumentException if the URI has no scheme or the scheme is not supported
     * @throws IOException declared by the http(s) path
     */
    static InputStream getInputStreamFromModelRepository(URI uri) throws IOException {
        String scheme = uri.getScheme();
        // relative URIs have no scheme; report them like any other unsupported scheme instead of failing with an NPE
        if (scheme == null) {
            throw new IllegalArgumentException("unsupported scheme");
        }
        switch (scheme.toLowerCase(Locale.ROOT)) {
            case "http":
            case "https":
                return getHttpOrHttpsInputStream(uri);
            case "file":
                return getFileInputStream(uri);
            default:
                throw new IllegalArgumentException("unsupported scheme");
        }
    }

    /**
     * Loads a vocabulary file and returns its {@code vocabulary} and {@code merges} entries.
     * Only files whose path ends with {@code .json} are supported; at most
     * {@code VOCABULARY_SIZE_LIMIT} bytes are read.
     *
     * @param uri location of the vocabulary file
     * @return a tuple of (vocabulary, merges); a missing entry yields an empty list
     * @throws RuntimeException wrapping any failure while opening or parsing the file
     */
    public static Tuple<List<String>, List<String>> loadVocabulary(URI uri) {
        // try-with-resources: the stream must be closed on the success path and on the unknown-format path
        try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) {

            if (uri.getPath().endsWith(".json")) {
                try (
                    XContentParser sourceParser = XContentType.JSON.xContent()
                        .createParser(
                            XContentParserConfiguration.EMPTY,
                            Streams.limitStream(vocabInputStream, VOCABULARY_SIZE_LIMIT.getBytes())
                        )
                ) {
                    Map<String, List<Object>> vocabAndMerges = sourceParser.map(HashMap::new, XContentParser::list);

                    List<String> vocabulary = toStringList(vocabAndMerges, VOCABULARY);
                    List<String> merges = toStringList(vocabAndMerges, MERGES);

                    return Tuple.tuple(vocabulary, merges);
                }
            }

            throw new IllegalArgumentException("unknown format vocabulary file format");
        } catch (Exception e) {
            throw new RuntimeException("Failed to load vocabulary file", e);
        }
    }

    /** Returns the values stored under {@code key} converted to strings, or an empty list if there are none. */
    private static List<String> toStringList(Map<String, List<Object>> source, String key) {
        List<Object> values = source.get(key);
        if (values == null) {
            return Collections.emptyList();
        }
        return values.stream().map(Object::toString).collect(Collectors.toList());
    }

    private ModelLoaderUtils() {}

    /**
     * Opens an http(s) connection and returns the response body stream.
     *
     * @param uri the http or https location
     * @return the response body stream for a 200 response; the caller is responsible for closing it
     * @throws IllegalStateException if the server answers with a 301, 302 or 303 redirect
     * @throws ResourceNotFoundException if the server answers with 404
     * @throws ElasticsearchStatusException for any other non-200 response
     * @throws UncheckedIOException if connecting or reading the response fails
     */
    @SuppressWarnings("removal")
    @SuppressForbidden(reason = "we need socket connection to download")
    private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException {

        SecurityManager sm = System.getSecurityManager();
        if (sm != null) {
            sm.checkPermission(new SpecialPermission());
        }

        PrivilegedAction<InputStream> privilegedHttpReader = () -> {
            try {
                HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection();
                int responseCode = conn.getResponseCode();
                if (responseCode == HTTP_OK) {
                    return conn.getInputStream();
                }

                // nothing is read from the connection on the error paths below, release it right away
                conn.disconnect();
                switch (responseCode) {
                    case HTTP_MOVED_PERM:
                    case HTTP_MOVED_TEMP:
                    case HTTP_SEE_OTHER:
                        throw new IllegalStateException("redirects aren't supported yet");
                    case HTTP_NOT_FOUND:
                        throw new ResourceNotFoundException("{} not found", uri);
                    default:
                        throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), uri);
                }
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        };

        return AccessController.doPrivileged(privilegedHttpReader);
    }

    /**
     * Opens a stream on a local file.
     *
     * @param uri a {@code file} URI
     * @return the opened stream; the caller is responsible for closing it
     * @throws ResourceNotFoundException if the file does not exist
     * @throws UncheckedIOException if the file cannot be opened
     */
    @SuppressWarnings("removal")
    @SuppressForbidden(reason = "we need load model data from a file")
    private static InputStream getFileInputStream(URI uri) {

        SecurityManager sm = System.getSecurityManager();
        if (sm != null) {
            sm.checkPermission(new SpecialPermission());
        }

        PrivilegedAction<InputStream> privilegedFileReader = () -> {
            File file = new File(uri);
            if (file.exists() == false) {
                throw new ResourceNotFoundException("{} not found", uri);
            }

            try {
                return Files.newInputStream(file.toPath());
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        };

        return AccessController.doPrivileged(privilegedFileReader);
    }

}

0 comments on commit 9114965

Please sign in to comment.