Skip to content

Commit

Permalink
Merge pull request #65 from ruivieira/RHOAIENG-4963-b
Browse files Browse the repository at this point in the history
RHOAIENG-4963: ModelMesh should support TLS in payload processors
  • Loading branch information
openshift-merge-bot[bot] committed Aug 1, 2024
2 parents cf3fcf6 + ed8161a commit 2405dba
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 16 deletions.
47 changes: 44 additions & 3 deletions src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,18 @@
import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList;

import javax.annotation.concurrent.GuardedBy;
import java.io.File;
import java.io.InterruptedIOException;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import java.io.*;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.management.MemoryUsage;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.nio.channels.ClosedByInterruptException;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.*;
Expand Down Expand Up @@ -428,10 +431,38 @@ public abstract class ModelMesh extends ThriftService
}
}

private static final String SSL_TRUSTSTORE_PATH_PROPERTY = "watson.ssl.truststore.path";
private static final String SSL_TRUSTSTORE_PASSWORD_PROPERTY = "watson.ssl.truststore.password";

private static SSLContext sslContext = null;

private static SSLContext loadSSLContext() throws Exception {
if (sslContext == null) {
final String trustStorePath = System.getProperty(SSL_TRUSTSTORE_PATH_PROPERTY);
final String trustStorePassword = System.getProperty(SSL_TRUSTSTORE_PASSWORD_PROPERTY);

if (trustStorePath == null || trustStorePassword == null) {
throw new IllegalArgumentException("Truststore settings not found in system properties");
}

final KeyStore trustStore = KeyStore.getInstance("JKS");
try (FileInputStream trustStoreStream = new FileInputStream(trustStorePath)) {
trustStore.load(trustStoreStream, trustStorePassword.toCharArray());
}

final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(trustStore);

sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustManagerFactory.getTrustManagers(), null);
}
return sslContext;
}

private PayloadProcessor initPayloadProcessor() {
String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null);
logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions);
if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) {
if (payloadProcessorsDefinitions != null && !payloadProcessorsDefinitions.isEmpty()) {
List<PayloadProcessor> payloadProcessors = new ArrayList<>();
for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) {
try {
Expand All @@ -441,7 +472,17 @@ private PayloadProcessor initPayloadProcessor() {
String modelId = uri.getQuery();
String method = uri.getFragment();
if ("http".equals(processorName)) {
logger.info("Initializing HTTP payload processor");
processor = new RemotePayloadProcessor(uri);
} else if ("https".equals(processorName)) {
SSLContext sslContext;
try {
sslContext = loadSSLContext();
} catch (Exception missingAlgorithmException) {
throw new UncheckedIOException(new IOException(missingAlgorithmException));
}
logger.info("Initializing HTTPS payload processor");
processor = new RemotePayloadProcessor(uri, sslContext, sslContext.getDefaultSSLParameters());
} else if ("logger".equals(processorName)) {
processor = new LoggingPayloadProcessor();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.grpc.Metadata;
Expand All @@ -42,11 +44,27 @@ public class RemotePayloadProcessor implements PayloadProcessor {

private final URI uri;

private final SSLContext sslContext;
private final SSLParameters sslParameters;

private final HttpClient client;

public RemotePayloadProcessor(URI uri) {
this(uri, null, null);
}

public RemotePayloadProcessor(URI uri, SSLContext sslContext, SSLParameters sslParameters) {
this.uri = uri;
this.client = HttpClient.newHttpClient();
this.sslContext = sslContext;
this.sslParameters = sslParameters;
if (sslContext != null && sslParameters != null) {
this.client = HttpClient.newBuilder()
.sslContext(sslContext)
.sslParameters(sslParameters)
.build();
} else {
this.client = HttpClient.newHttpClient();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,55 @@

package com.ibm.watson.modelmesh.payload;

import java.io.IOException;
import java.net.URI;
import java.security.NoSuchAlgorithmException;

import io.grpc.Metadata;
import io.grpc.Status;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.Test;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;

import static org.junit.jupiter.api.Assertions.assertFalse;

class RemotePayloadProcessorTest {

void testDestinationUnreachable() throws IOException {
URI uri = URI.create("http://this-does-not-exist:123");
try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri)) {
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
}
}

@Test
void testDestinationUnreachable() {
RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(URI.create("http://this-does-not-exist:123"));
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
void testDestinationUnreachableHTTPS() throws IOException, NoSuchAlgorithmException {
URI uri = URI.create("https://this-does-not-exist:123");
SSLContext sslContext = SSLContext.getDefault();
SSLParameters sslParameters = sslContext.getDefaultSSLParameters();
try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri, sslContext, sslParameters)) {
String id = "123";
String modelId = "456";
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
}
}
}

0 comments on commit 2405dba

Please sign in to comment.