Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.sql.planner.SerializablePlan;
import org.opensearch.sql.planner.physical.PhysicalPlan;
import org.opensearch.sql.storage.StorageEngine;
import org.opensearch.sql.utils.DeserializationFilterUtil;

/**
* This class is entry point to paged requests. It is responsible to cursor serialization and
Expand Down Expand Up @@ -88,6 +89,14 @@ protected Serializable deserialize(String code) {
new GZIPInputStream(new ByteArrayInputStream(HashCode.fromString(code).asBytes()));
ObjectInputStream objectInput =
new CursorDeserializationStream(new ByteArrayInputStream(gzip.readAllBytes()));
objectInput.setObjectInputFilter(
DeserializationFilterUtil.createFilter(
"org.opensearch.sql.planner.physical.*;"
+ "org.opensearch.sql.opensearch.storage.scan.*;"
+ "org.opensearch.sql.opensearch.data.type.*;"
+ "org.opensearch.sql.executor.pagination.*;"
+ "org.opensearch.sql.executor.QueryType;"
+ "org.opensearch.sql.utils.*;"));
return (Serializable) objectInput.readObject();
} catch (Exception e) {
throw new IllegalStateException("Failed to deserialize object", e);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.utils;

import java.io.ObjectInputFilter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/** Utility class for creating deserialization filters with logging. */
public class DeserializationFilterUtil {
private static final Logger LOG = LogManager.getLogger(DeserializationFilterUtil.class);

/** Base allowlist shared across all serializers. */
private static final String BASE_ALLOWLIST =
"org.opensearch.sql.expression.**;"
+ "org.opensearch.sql.data.**;"
+ "org.opensearch.sql.executor.QueryType;"
+ "org.opensearch.sql.opensearch.data.type.*;"
+ "java.lang.Object;"
+ "java.lang.String;"
+ "java.lang.Number;"
+ "java.lang.Integer;"
+ "java.lang.Long;"
+ "java.lang.Double;"
+ "java.lang.Float;"
+ "java.lang.Short;"
+ "java.lang.Byte;"
+ "java.lang.Boolean;"
+ "java.lang.Character;"
+ "java.lang.Enum;"
+ "java.util.ArrayList;"
+ "java.util.Arrays$ArrayList;"
+ "java.util.LinkedHashMap;"
+ "java.util.HashMap;"
+ "java.util.Collections$*;"
+ "java.util.ImmutableCollections$*;"
+ "java.util.CollSer;"
+ "java.util.Map$Entry;"
+ "java.io.Serializable;"
+ "java.lang.invoke.SerializedLambda;"
+ "java.math.BigDecimal;"
+ "java.math.BigInteger;"
+ "java.time.**;"
+ "com.google.common.collect.**;";

/**
* Creates a logging filter that wraps the provided filter and logs rejected classes.
*
* @param filter The underlying filter to wrap.
* @return A filter that logs rejections.
*/
public static ObjectInputFilter createLoggingFilter(ObjectInputFilter filter) {
return info -> {
ObjectInputFilter.Status status = filter.checkInput(info);
if (status == ObjectInputFilter.Status.REJECTED && info.serialClass() != null) {
LOG.warn("Deserialization filter rejected class: {}", info.serialClass().getName());
}
return status;
};
}

/**
* Creates a filter with the base allowlist plus additional patterns.
*
* @param additionalPatterns Additional patterns to append to the base allowlist.
* @return A logging filter with the combined allowlist.
*/
public static ObjectInputFilter createFilter(String additionalPatterns) {
String fullPattern = BASE_ALLOWLIST + additionalPatterns + "!*";
return createLoggingFilter(ObjectInputFilter.Config.createFilter(fullPattern));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ void resolveObject() {
assertSame(object, cds.resolveObject(object));
}

@Test
@SneakyThrows
void deserialize_rejects_disallowed_class() {
String serialized = serialize(new java.net.URL("http://example.com"));
var exception = assertThrows(IllegalStateException.class, () -> deserialize(serialized));
assertTrue(exception.getMessage().contains("Failed to deserialize"));
}

// Helpers and auxiliary classes section below

@SneakyThrows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.io.ObjectOutputStream;
import java.util.Base64;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.utils.DeserializationFilterUtil;

/** Default serializer that (de-)serialize expressions by JDK serialization. */
public class DefaultExpressionSerializer implements ExpressionSerializer {
Expand All @@ -34,6 +35,7 @@ public Expression deserialize(String code) {
try {
ByteArrayInputStream input = new ByteArrayInputStream(Base64.getDecoder().decode(code));
ObjectInputStream objectInput = new ObjectInputStream(input);
objectInput.setObjectInputFilter(DeserializationFilterUtil.createFilter(""));
return (Expression) objectInput.readObject();
} catch (Exception e) {
throw new IllegalStateException("Failed to deserialize expression code: " + code, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.sql.calcite.CalcitePlanContext;
import org.opensearch.sql.expression.function.PPLBuiltinOperators;
import org.opensearch.sql.opensearch.executor.OpenSearchExecutionEngine.OperatorTable;
import org.opensearch.sql.utils.DeserializationFilterUtil;

/**
* A serializer that (de-)serializes Calcite RexNode, RelDataType and OpenSearch field mapping.
Expand Down Expand Up @@ -120,6 +121,7 @@ public RexNode deserialize(String struct) {
try {
ByteArrayInputStream input = new ByteArrayInputStream(Base64.getDecoder().decode(struct));
ObjectInputStream objectInput = new ObjectInputStream(input);
objectInput.setObjectInputFilter(DeserializationFilterUtil.createFilter(""));
exprStr = (String) objectInput.readObject();

// Deserialize RelDataType and RexNode by JSON
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.DSL.literal;
import static org.opensearch.sql.expression.DSL.ref;
Expand Down Expand Up @@ -82,4 +83,16 @@ public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
public void cannot_deserialize_illegal_expression_code() {
assertThrows(IllegalStateException.class, () -> serializer.deserialize("hello world"));
}

@Test
public void deserialize_rejects_disallowed_class() throws Exception {
java.io.ByteArrayOutputStream output = new java.io.ByteArrayOutputStream();
java.io.ObjectOutputStream objectOutput = new java.io.ObjectOutputStream(output);
objectOutput.writeObject(new java.net.URL("http://example.com"));
objectOutput.flush();
String encoded = java.util.Base64.getEncoder().encodeToString(output.toByteArray());
var exception =
assertThrows(IllegalStateException.class, () -> serializer.deserialize(encoded));
assertTrue(exception.getMessage().contains("Failed to deserialize"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;

import com.google.common.collect.ImmutableRangeSet;
Expand Down Expand Up @@ -348,4 +349,16 @@ void testSerializeAndDeserializeSearch() {
assertEquals(List.of(2, 0, 0, 2), helper.sources);
assertEquals(List.of(20, "Number", "Number", 10), helper.digests);
}

@Test
void deserialize_rejects_disallowed_class() throws Exception {
java.io.ByteArrayOutputStream output = new java.io.ByteArrayOutputStream();
java.io.ObjectOutputStream objectOutput = new java.io.ObjectOutputStream(output);
objectOutput.writeObject(new java.net.URL("http://example.com"));
objectOutput.flush();
String encoded = java.util.Base64.getEncoder().encodeToString(output.toByteArray());
var exception =
assertThrows(IllegalStateException.class, () -> serializer.deserialize(encoded));
assertTrue(exception.getMessage().contains("Failed to deserialize"));
}
}
Loading