Skip to content

Commit

Permalink
WIP: Include the ability to define serialization filters
Browse files Browse the repository at this point in the history
  • Loading branch information
jccampanero committed Mar 29, 2021
1 parent 716e147 commit a24a1df
Show file tree
Hide file tree
Showing 7 changed files with 402 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
/*
* Copyright 2020 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package de.javakaffee.web.msm;

import java.io.ObjectInputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;

/**
* Convenient class for help in the definition of filters for serialization.
*
* Based on the analogous class of Keycloak: https://github.com/keycloak/keycloak/blob/a60cb65252aec21ea1899a73d5fc136ae5058383/common/src/main/java/org/keycloak/common/util/DelegatingSerializationFilter.java#L30
*/
public class DelegatingSerializationFilter {
private static final Log LOG = LogFactory.getLog( JavaSerializationTranscoder.class );

private static final SerializationFilterAdapter serializationFilterAdapter = isJava6To8() ? createOnJava6To8Adapter() : createOnJavaAfter8Adapter();

private static boolean isJava6To8() {
List<String> olderVersions = Arrays.asList("1.6", "1.7", "1.8");
return olderVersions.contains(System.getProperty("java.specification.version"));
}

private DelegatingSerializationFilter() {
}

public static DelegatingSerializationFilter.FilterPatternBuilder builder() {
return new DelegatingSerializationFilter.FilterPatternBuilder();
}

private void setFilter(ObjectInputStream ois, String filterPattern) {
LOG.debug("Using: " + serializationFilterAdapter.getClass().getSimpleName());

if (serializationFilterAdapter.getObjectInputFilter(ois) == null) {
serializationFilterAdapter.setObjectInputFilter(ois, filterPattern);
}
}

interface SerializationFilterAdapter {

Object getObjectInputFilter(ObjectInputStream ois);

void setObjectInputFilter(ObjectInputStream ois, String filterPattern);
}

private static SerializationFilterAdapter createOnJava6To8Adapter() {
try {
ClassLoader cl = Thread.currentThread().getContextClassLoader();
Class<?> objectInputFilterClass = cl.loadClass("sun.misc.ObjectInputFilter");
Class<?> objectInputFilterConfigClass = cl.loadClass("sun.misc.ObjectInputFilter$Config");
Method getObjectInputFilter = objectInputFilterConfigClass.getDeclaredMethod("getObjectInputFilter", ObjectInputStream.class);
Method setObjectInputFilter = objectInputFilterConfigClass.getDeclaredMethod("setObjectInputFilter", ObjectInputStream.class, objectInputFilterClass);
Method createFilter = objectInputFilterConfigClass.getDeclaredMethod("createFilter", String.class);
LOG.info("Using OnJava6To8 serialization filter adapter");
return new OnJava6To8(getObjectInputFilter, setObjectInputFilter, createFilter);
} catch (ClassNotFoundException e) {
// This can happen for older JDK updates.
LOG.warn("Could not configure SerializationFilterAdapter. For better security, it is highly recommended to upgrade to newer JDK version update!");
LOG.warn("For the Java 7, the recommended update is at least 131 (1.7.0_131 or newer). For the Java 8, the recommended update is at least 121 (1.8.0_121 or newer).");
LOG.warn("Error details", e);
return new EmptyFilterAdapter();
} catch (NoSuchMethodException e) {
// This can happen for older JDK updates.
LOG.warn("Could not configure SerializationFilterAdapter. For better security, it is highly recommended to upgrade to newer JDK version update!");
LOG.warn("For the Java 7, the recommended update is at least 131 (1.7.0_131 or newer). For the Java 8, the recommended update is at least 121 (1.8.0_121 or newer).");
LOG.warn("Error details", e);
return new EmptyFilterAdapter();
}
}

private static SerializationFilterAdapter createOnJavaAfter8Adapter() {
try {
ClassLoader cl = Thread.currentThread().getContextClassLoader();
Class<?> objectInputFilterClass = cl.loadClass("java.io.ObjectInputFilter");
Class<?> objectInputFilterConfigClass = cl.loadClass("java.io.ObjectInputFilter$Config");
Class<?> objectInputStreamClass = cl.loadClass("java.io.ObjectInputStream");
Method getObjectInputFilter = objectInputStreamClass.getDeclaredMethod("getObjectInputFilter");
Method setObjectInputFilter = objectInputStreamClass.getDeclaredMethod("setObjectInputFilter", objectInputFilterClass);
Method createFilter = objectInputFilterConfigClass.getDeclaredMethod("createFilter", String.class);
LOG.info("Using OnJavaAfter8 serialization filter adapter");
return new OnJavaAfter8(getObjectInputFilter, setObjectInputFilter, createFilter);
} catch (ClassNotFoundException e) {
// This can happen for older JDK updates.
LOG.warn("Could not configure SerializationFilterAdapter. For better security, it is highly recommended to upgrade to newer JDK version update!");
LOG.warn("Error details", e);
return new EmptyFilterAdapter();
} catch (NoSuchMethodException e) {
// This can happen for older JDK updates.
LOG.warn("Could not configure SerializationFilterAdapter. For better security, it is highly recommended to upgrade to newer JDK version update!");
LOG.warn("Error details", e);
return new EmptyFilterAdapter();
}
}

// If codebase stays on Java 8 for a while you could use Java 8 classes directly without reflection
static class OnJava6To8 implements SerializationFilterAdapter {

private final Method getObjectInputFilterMethod;
private final Method setObjectInputFilterMethod;
private final Method createFilterMethod;

private OnJava6To8(Method getObjectInputFilterMethod, Method setObjectInputFilterMethod, Method createFilterMethod) {
this.getObjectInputFilterMethod = getObjectInputFilterMethod;
this.setObjectInputFilterMethod = setObjectInputFilterMethod;
this.createFilterMethod = createFilterMethod;
}

public Object getObjectInputFilter(ObjectInputStream ois) {
try {
return getObjectInputFilterMethod.invoke(null, ois);
} catch (IllegalAccessException e) {
LOG.warn("Could not read ObjectFilter from ObjectInputStream: " + e.getMessage());
return null;
} catch (InvocationTargetException e) {
LOG.warn("Could not read ObjectFilter from ObjectInputStream: " + e.getMessage());
return null;
}
}

public void setObjectInputFilter(ObjectInputStream ois, String filterPattern) {
try {
Object objectFilter = createFilterMethod.invoke(null, filterPattern);
setObjectInputFilterMethod.invoke(null, ois, objectFilter);
} catch (IllegalAccessException e) {
LOG.warn("Could not set ObjectFilter: " + e.getMessage());
} catch (InvocationTargetException e) {
LOG.warn("Could not set ObjectFilter: " + e.getMessage());
}
}
}


static class EmptyFilterAdapter implements SerializationFilterAdapter {

@Override
public Object getObjectInputFilter(ObjectInputStream ois) {
return null;
}

@Override
public void setObjectInputFilter(ObjectInputStream ois, String filterPattern) {

}

}


// If codebase moves to Java 9+ could use Java 9+ classes directly without reflection and keep the old variant with reflection
static class OnJavaAfter8 implements SerializationFilterAdapter {

private final Method getObjectInputFilterMethod;
private final Method setObjectInputFilterMethod;
private final Method createFilterMethod;

private OnJavaAfter8(Method getObjectInputFilterMethod, Method setObjectInputFilterMethod, Method createFilterMethod) {
this.getObjectInputFilterMethod = getObjectInputFilterMethod;
this.setObjectInputFilterMethod = setObjectInputFilterMethod;
this.createFilterMethod = createFilterMethod;
}

public Object getObjectInputFilter(ObjectInputStream ois) {
try {
return getObjectInputFilterMethod.invoke(ois);
} catch (IllegalAccessException e) {
LOG.warn("Could not read ObjectFilter from ObjectInputStream: " + e.getMessage());
return null;
} catch (InvocationTargetException e) {
LOG.warn("Could not read ObjectFilter from ObjectInputStream: " + e.getMessage());
return null;
}
}

public void setObjectInputFilter(ObjectInputStream ois, String filterPattern) {
try {
Object objectFilter = createFilterMethod.invoke(ois, filterPattern);
setObjectInputFilterMethod.invoke(ois, objectFilter);
} catch (IllegalAccessException e) {
LOG.warn("Could not set ObjectFilter: " + e.getMessage());
} catch (InvocationTargetException e) {
LOG.warn("Could not set ObjectFilter: " + e.getMessage());
}
}
}


public static class FilterPatternBuilder {

private Set<Class> classes = new HashSet();
private Set<String> patterns = new HashSet();

public FilterPatternBuilder() {
// Add "java.util" package by default (contains all the basic collections)
addAllowedPattern("java.util.*");
}

/**
* This is used when the caller of this method can't use the {@link #addAllowedClass(Class)}. For example because the
* particular is private or it is not available at the compile time. Or when adding the whole package like "java.util.*"
*
* @param pattern
* @return
*/
public FilterPatternBuilder addAllowedPattern(String pattern) {
this.patterns.add(pattern);
return this;
}

public FilterPatternBuilder addAllowedClass(Class javaClass) {
this.classes.add(javaClass);
return this;
}

@Override
public String toString() {
StringBuilder builder = new StringBuilder();

for (Class javaClass : classes) {
builder.append(javaClass.getName()).append(";");
}
for (String pattern : patterns) {
builder.append(pattern).append(";");
}

builder.append("!*");

return builder.toString();
}

public void setFilter(ObjectInputStream ois) {
DelegatingSerializationFilter filter = new DelegatingSerializationFilter();
String filterPattern = this.toString();
filter.setFilter(ois, filterPattern);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ public class JavaSerializationTranscoder implements SessionAttributesTranscoder

/**
* Constructor.
*
* @param manager
* the manager
*/
public JavaSerializationTranscoder() {
this( null );
Expand Down Expand Up @@ -166,6 +163,8 @@ public ConcurrentMap<String, Object> deserializeAttributes(final byte[] in ) {
bis = new ByteArrayInputStream( in );
ois = createObjectInputStream( bis );

applySerializationFilter( ois );

final ConcurrentMap<String, Object> attributes = new ConcurrentHashMap<String, Object>();
final int n = ( (Integer) ois.readObject() ).intValue();
for ( int i = 0; i < n; i++ ) {
Expand Down Expand Up @@ -227,4 +226,31 @@ private void closeSilently( final InputStream is ) {
}
}

private void applySerializationFilter(ObjectInputStream ois) {
String serialFilter = getSerialFilter();
if (serialFilter != null) {
LOG.debug("Appying serialization filter: " + serialFilter);
applyDeserializationFilter(ois, serialFilter);
}
}

private String getSerialFilter() {
if ( this._manager == null ) {
LOG.debug("Manager not set. Returning null serial filter");
return null;
}

final String serialFilter = this._manager.getMemcachedSessionService().getSerialFilter();
LOG.debug("Serial filter for deserialization: " + serialFilter);
return serialFilter;
}

private void applyDeserializationFilter(ObjectInputStream ois, String serialFilter) {
DelegatingSerializationFilter
.builder()
.addAllowedPattern(serialFilter)
.setFilter(ois)
;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,22 @@ static enum LockStatus {
*/
private boolean _copyCollectionsForSerialization = false;

/**
* The pattern (as specified by ObjectInputFilter) that should be applied when performing
* session deserialization when using Java standard serialization.
*
* Different patterns should be separated by ;.
*
* The classes not included in the pattern will be rejected.
*
* <p>
* E.g. <code>!somepackage.*;someotherpackage.SomeClass</code>
* </p>
*
* If not set, the filter will not be applied
*/
private String _serialFilter = null;

private String _customConverterClassNames;

private boolean _enableStatistics = true;
Expand Down Expand Up @@ -1713,6 +1729,20 @@ public void setLockExpiration(final int lockExpiration) {
_lockExpiration = lockExpiration;
}


/**
* Return filter pattern to be applied in the serialization mechanism to prevent deserialization vulnerabilities.
*
* @return the serialFilter
*/
public String getSerialFilter() {
return _serialFilter;
}

public void setSerialFilter(String serialFilter) {
this._serialFilter = serialFilter;
}

// ----------------------- protected getters/setters for testing ------------------

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,28 @@ public void setCopyCollectionsForSerialization( final boolean copyCollectionsFor
_msm.setCopyCollectionsForSerialization( copyCollectionsForSerialization );
}


/**
* Set the pattern (as specified by ObjectInputFilter) that should be applied when performing
* session deserialization when using Java standard serialization.
*
* Different patterns should be separated by ;.
*
* The classes not included in the pattern will be rejected.
*
* <p>
* E.g. <code>!somepackage.*;someotherpackage.SomeClass</code>
* </p>
*
* If not set, the filter will not be applied
*
* @param serialFilter
* the filter pattern to set
*/
public void setSerialFilter( @Nullable final String serialFilter ) {
_msm.setSerialFilter( serialFilter );
}

/**
* Custom converter allow you to provide custom serialization of application specific
* types. Multiple converter classes are separated by comma (with optional space following the comma).
Expand Down

0 comments on commit a24a1df

Please sign in to comment.