Permalink
Browse files

Allow correct constructors to be selected with a custom serializabili…

…ty checker
  • Loading branch information...
1 parent 89721e0 commit bd4c170864e3b8086cab85f9a6e1afe6166850b4 @dmlloyd dmlloyd committed Apr 9, 2012
@@ -235,7 +235,11 @@ private Object clone(final Object orig, final boolean replace) throws IOExceptio
soo.doFinish();
((Externalizable) clone).readExternal(new StepObjectInput(steps));
} else if (serializabilityChecker.isSerializable(objClass)) {
- clone = info.callNonInitConstructor();
+ Class<?> nonSerializable;
+ for (nonSerializable = objClass.getSuperclass(); serializabilityChecker.isSerializable(nonSerializable); nonSerializable = nonSerializable.getSuperclass()) {
+ if (nonSerializable == Object.class) break;
+ }
+ clone = info.callNonInitConstructor(nonSerializable);
final Class<?> cloneClass = clone.getClass();
if (! (serializabilityChecker.isSerializable(cloneClass))) {
throw new NotSerializableException(cloneClass.getName());
@@ -23,7 +23,6 @@
package org.jboss.marshalling.reflect;
import java.io.ObjectInput;
-import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.InvocationTargetException;
@@ -42,6 +41,7 @@
import java.util.Comparator;
import java.util.Collections;
import java.util.HashMap;
+import java.util.IdentityHashMap;
import java.util.Map;
import sun.reflect.ReflectionFactory;
@@ -50,6 +50,7 @@
*/
public final class SerializableClass {
private static final ReflectionFactory reflectionFactory;
+ private static final SerializableClassRegistry REGISTRY = SerializableClassRegistry.getInstanceUnchecked();
static {
reflectionFactory = AccessController.doPrivileged(new PrivilegedAction<ReflectionFactory>() {
@@ -59,6 +60,7 @@ public ReflectionFactory run() {
});
}
+ private final IdentityHashMap<Class<?>, Constructor<?>> nonInitConstructors;
private final Class<?> subject;
private final Method writeObject;
private final Method writeReplace;
@@ -67,7 +69,6 @@ public ReflectionFactory run() {
private final Method readResolve;
private final Constructor<?> noArgConstructor;
private final Constructor<?> objectInputConstructor;
- private final Constructor<?> nonInitConstructor;
private final SerializableField[] fields;
private final Map<String, SerializableField> fieldsByName;
private final long effectiveSerialVersionUID;
@@ -84,6 +85,17 @@ public int compare(final SerializableField o1, final SerializableField o2) {
SerializableClass(Class<?> subject) {
this.subject = subject;
+ final IdentityHashMap<Class<?>, Constructor<?>> constructorMap = new IdentityHashMap<Class<?>, Constructor<?>>();
+ for (Class<?> t = subject.getSuperclass(); t != null; t = t.getSuperclass()) {
+ final SerializableClass lookedUp = REGISTRY.lookup(t);
+ final Constructor<?> constructor = lookedUp.noArgConstructor;
+ if (constructor != null) {
+ final Constructor newConstructor = reflectionFactory.newConstructorForSerialization(subject, constructor);
+ newConstructor.setAccessible(true);
+ constructorMap.put(t, newConstructor);
+ }
+ }
+ nonInitConstructors = constructorMap;
// private methods
Method writeObject = null;
Method readObject = null;
@@ -136,7 +148,7 @@ public int compare(final SerializableField o1, final SerializableField o2) {
if (readResolve == null || writeReplace == null) {
final Class<?> superclass = subject.getSuperclass();
if (superclass != null) {
- final SerializableClass superInfo = SerializableClassRegistry.getInstanceUnchecked().lookup(superclass);
+ final SerializableClass superInfo = REGISTRY.lookup(superclass);
final Method otherReadResolve = superInfo.readResolve;
if (readResolve == null && otherReadResolve != null && ! Modifier.isPrivate(otherReadResolve.getModifiers())) {
readResolve = otherReadResolve;
@@ -149,7 +161,7 @@ public int compare(final SerializableField o1, final SerializableField o2) {
}
Constructor<?> noArgConstructor = null;
Constructor<?> objectInputConstructor = null;
- for (Constructor<?> constructor : subject.getConstructors()) {
+ for (Constructor<?> constructor : subject.getDeclaredConstructors()) {
final Class<?>[] parameterTypes = constructor.getParameterTypes();
if (parameterTypes.length == 0) {
noArgConstructor = constructor;
@@ -166,7 +178,6 @@ public int compare(final SerializableField o1, final SerializableField o2) {
this.objectInputConstructor = objectInputConstructor;
this.readResolve = readResolve;
this.writeReplace = writeReplace;
- nonInitConstructor = lookupNonInitConstructor(subject);
final ObjectStreamClass objectStreamClass = ObjectStreamClass.lookup(subject);
effectiveSerialVersionUID = objectStreamClass == null ? 0L : objectStreamClass.getSerialVersionUID(); // todo find a better solution
final HashMap<String, SerializableField> fieldsByName = new HashMap<String, SerializableField>();
@@ -433,12 +444,12 @@ public Object callReadResolve(Object object) throws ObjectStreamException {
*
* @return {@code true} if there is such a constructor
*/
- public boolean hasNoArgConstructor() {
- return noArgConstructor != null;
+ public boolean hasPublicNoArgConstructor() {
+ return noArgConstructor != null && Modifier.isPublic(noArgConstructor.getModifiers());
}
/**
- * Invoke the public no-arg constructor on this class.
+ * Invoke the no-arg constructor on this class.
*
* @return the new instance
* @throws IOException if an I/O error occurs
@@ -453,7 +464,7 @@ public Object callNoArgConstructor() throws IOException {
* @return {@code true} if there is such a constructor
*/
public boolean hasObjectInputConstructor() {
- return objectInputConstructor != null;
+ return objectInputConstructor != null && Modifier.isPublic(objectInputConstructor.getModifiers());
}
/**
@@ -472,20 +483,23 @@ public Object callObjectInputConstructor(final ObjectInput objectInput) throws I
*
* @return whether this class has a non-init constructor
*/
- public boolean hasNoInitConstructor() {
- return nonInitConstructor != null;
+ public boolean hasNoInitConstructor(Class<?> target) {
+ return nonInitConstructors.containsKey(target);
}
/**
* Invoke the non-init constructor on this class.
*
* @return the new instance
*/
- public Object callNonInitConstructor() {
- return invokeConstructorNoException(nonInitConstructor);
+ public Object callNonInitConstructor(Class<?> target) {
+ return invokeConstructorNoException(nonInitConstructors.get(target));
}
- private static Object invokeConstructor(Constructor<?> constructor, Object... args) throws IOException {
+ private static <T> T invokeConstructor(Constructor<T> constructor, Object... args) throws IOException {
+ if (constructor == null) {
+ throw new IllegalArgumentException("No matching constructor");
+ }
try {
return constructor.newInstance(args);
} catch (InvocationTargetException e) {
@@ -506,7 +520,10 @@ private static Object invokeConstructor(Constructor<?> constructor, Object... ar
}
}
- private static Object invokeConstructorNoException(Constructor<?> constructor, Object... args) {
+ private static <T> T invokeConstructorNoException(Constructor<T> constructor, Object... args) {
+ if (constructor == null) {
+ throw new IllegalArgumentException("No matching constructor");
+ }
try {
return constructor.newInstance(args);
} catch (InvocationTargetException e) {
@@ -543,27 +560,6 @@ public long getEffectiveSerialVersionUID() {
return subject;
}
- @SuppressWarnings("unchecked")
- private static <T> Constructor<T> lookupNonInitConstructor(final Class<T> subject) {
- Class<? super T> current = subject;
- for (; Serializable.class.isAssignableFrom(current); current = current.getSuperclass());
- final Constructor<? super T> topConstructor;
- try {
- topConstructor = current.getDeclaredConstructor();
- } catch (NoSuchMethodException e) {
- return null;
- }
- topConstructor.setAccessible(true);
- final Constructor<T> generatedConstructor = (Constructor<T>) reflectionFactory.newConstructorForSerialization(subject, topConstructor);
- generatedConstructor.setAccessible(true);
- return generatedConstructor;
- }
-
- @SuppressWarnings("unchecked")
- <T> Constructor<T> getNoInitConstructor() {
- return (Constructor<T>) nonInitConstructor;
- }
-
@SuppressWarnings("unchecked")
<T> Constructor<T> getNoArgConstructor() {
return (Constructor<T>) noArgConstructor;
@@ -1191,7 +1191,13 @@ protected Object doReadNewObject(final int streamClassType, final boolean unshar
switch (classType) {
case ID_PROXY_CLASS: {
final Class<?> type = descriptor.getType();
- final Object obj = registry.lookup(type).callNonInitConstructor();
+ final Class<?> nonSerializableSuperclass;
+ if (descriptor instanceof SerializableClassDescriptor) {
+ nonSerializableSuperclass = ((SerializableClassDescriptor) descriptor).getNonSerializableSuperclass();
+ } else {
+ nonSerializableSuperclass = Object.class;
+ }
+ final Object obj = registry.lookup(type).callNonInitConstructor(nonSerializableSuperclass);
final int idx = instanceCache.size();
instanceCache.add(obj);
try {
@@ -1211,7 +1217,7 @@ protected Object doReadNewObject(final int streamClassType, final boolean unshar
case ID_SERIALIZABLE_CLASS: {
final SerializableClassDescriptor serializableClassDescriptor = (SerializableClassDescriptor) descriptor;
final SerializableClass serializableClass = serializableClassDescriptor.getSerializableClass();
- final Object obj = serializableClass.callNonInitConstructor();
+ final Object obj = serializableClass.callNonInitConstructor(serializableClassDescriptor.getNonSerializableSuperclass());
final int idx = instanceCache.size();
instanceCache.add(obj);
doInitSerializable(obj, serializableClassDescriptor);
@@ -1230,7 +1236,7 @@ protected Object doReadNewObject(final int streamClassType, final boolean unshar
final Externalizable obj;
if (serializableClass.hasObjectInputConstructor()) {
obj = (Externalizable) serializableClass.callObjectInputConstructor(blockUnmarshaller);
- } else if (serializableClass.hasNoArgConstructor()) {
+ } else if (serializableClass.hasPublicNoArgConstructor()) {
obj = (Externalizable) serializableClass.callNoArgConstructor();
} else {
throw new InvalidClassException(type.getName(), "Class is non-public or has no public no-arg constructor");
@@ -50,4 +50,13 @@ public String toString() {
return String.format("%s for %s (id %x02) extends %s", getClass().getSimpleName(), getType(), Integer.valueOf(getTypeID()), superClassDescriptor);
}
}
+
+ public Class<?> getNonSerializableSuperclass() {
+ final ClassDescriptor descriptor = getSuperClassDescriptor();
+ if (descriptor instanceof SerializableClassDescriptor) {
+ return ((SerializableClassDescriptor) descriptor).getNonSerializableSuperclass();
+ } else {
+ return descriptor.getType();
+ }
+ }
}
@@ -258,7 +258,7 @@ Object doReadObject(int leadByte, final boolean unshared) throws IOException, Cl
if ((descriptor.getFlags() & SC_EXTERNALIZABLE) != 0) {
if (sc.hasObjectInputConstructor()) {
obj = sc.callObjectInputConstructor(blockUnmarshaller);
- } else if (sc.hasNoArgConstructor()) {
+ } else if (sc.hasPublicNoArgConstructor()) {
obj = sc.callNoArgConstructor();
} else {
throw new InvalidClassException(objClass.getName(), "Class is non-public or has no public no-arg constructor");
@@ -279,7 +279,11 @@ Object doReadObject(int leadByte, final boolean unshared) throws IOException, Cl
throw new InvalidObjectException("Created object should be Externalizable but it is not");
}
} else {
- obj = sc.callNonInitConstructor();
+ Class<?> nonSerializable;
+ for (nonSerializable = objClass.getSuperclass(); serializabilityChecker.isSerializable(nonSerializable); nonSerializable = nonSerializable.getSuperclass()) {
+ if (nonSerializable == Object.class) break;
+ }
+ obj = sc.callNonInitConstructor(nonSerializable);
if (obj instanceof Externalizable) {
throw new InvalidObjectException("Created object should not be Externalizable but it is");
}

0 comments on commit bd4c170

Please sign in to comment.