diff --git a/async/src/main/java/com/ea/orbit/async/instrumentation/Transformer.java b/async/src/main/java/com/ea/orbit/async/instrumentation/Transformer.java index e2f76072e..8857ed69a 100644 --- a/async/src/main/java/com/ea/orbit/async/instrumentation/Transformer.java +++ b/async/src/main/java/com/ea/orbit/async/instrumentation/Transformer.java @@ -39,6 +39,7 @@ import org.objectweb.asm.tree.AnnotationNode; import org.objectweb.asm.tree.ClassNode; import org.objectweb.asm.tree.FrameNode; +import org.objectweb.asm.tree.InsnList; import org.objectweb.asm.tree.InsnNode; import org.objectweb.asm.tree.LabelNode; import org.objectweb.asm.tree.LocalVariableNode; @@ -60,10 +61,14 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.objectweb.asm.Opcodes.*; @@ -102,7 +107,6 @@ public class Transformer implements ClassFileTransformer private static final Type COMPLETION_STAGE_TYPE = Type.getType("Ljava/util/concurrent/CompletionStage;"); private static final String COMPLETION_STAGE_RET = ")Ljava/util/concurrent/CompletionStage;"; - private static final Type OBJECT_TYPE = Type.getType(Object.class); private static final String _THIS = "_this"; private static final String TASK_DESCRIPTOR = "Lcom/ea/orbit/concurrent/Task;"; @@ -271,6 +275,7 @@ public byte[] transform(ClassReader cr) throws AnalyzerException // - in this order, the push order // - as an int array: int[parameter_index] = local_index int newMaxLocals = 0; + Set uninitializedObjects = new HashSet<>(); for (SwitchEntry se : switchEntries) { // clear the used state @@ -290,7 +295,13 @@ public byte[] transform(ClassReader cr) throws AnalyzerException { se.stackToNewLocal[j] = -1; } + // marks uninitialized objects + if (isUninitialized(value)) + { + uninitializedObjects.add(((FrameAnalyzer.ExtendedValue) value).insnNode); + } } + newMaxLocals = Math.max(iNewLocal, newMaxLocals); se.localToiArgument = new int[newMaxLocals]; Arrays.fill(se.localToiArgument, -1); @@ -302,6 +313,11 @@ public byte[] transform(ClassReader cr) throws AnalyzerException { mapLocalToLambdaArgument(original, se, arguments, iLocal, value); } + // marks uninitialized objects + if (isUninitialized(value)) + { + uninitializedObjects.add(((FrameAnalyzer.ExtendedValue) value).insnNode); + } } // maps the stack locals to arguments for (int j = 0; j < se.frame.getStackSize(); j++) @@ -323,7 +339,12 @@ public byte[] transform(ClassReader cr) throws AnalyzerException } } } - replaceObjectInitialization(classNode, original, nameUseCount, frames); + // only replaces object initialization + // if uninitialized objects are present in the stack during a await call. + if (uninitializedObjects.size() > 0) + { + replaceObjectInitialization(original, frames, uninitializedObjects); + } original.maxLocals = Math.max(original.maxLocals, newMaxLocals); } @@ -556,12 +577,18 @@ private int valueSize(final Value local) return local == null ? 1 : local.getSize(); } + // replacing only the initialization of objects that are uninitialized at the moment of an await call. void replaceObjectInitialization( - ClassNode classNode, final MethodNode methodNode, - final Map nameUseCount, final Frame[] frames) + final MethodNode methodNode, + final Frame[] frames, + final Set objectCreationNodes) { int originalLocals = methodNode.maxLocals; - int originalStack = methodNode.maxStack; + final Set uninitializedObjects = objectCreationNodes != null + ? objectCreationNodes + : Stream.of(methodNode.instructions.toArray()) + .filter(i -> i.getOpcode() == NEW) + .collect(Collectors.toSet()); // since we can't store uninitialized objects they have to be removed or replaced. // this works for bytecodes where the initialization is implemented like: @@ -575,7 +602,6 @@ void replaceObjectInitialization( // this conforms all cases of java derived bytecode that I'm aware of. // but it might not be always true. - // TODO: we should only replace the initialization of objects that are uninitialized at the moment of an await call. // replace frameNodes and constructor calls int index = 0; @@ -584,189 +610,176 @@ void replaceObjectInitialization( if (insnNode instanceof FrameNode) { FrameNode frameNode = (FrameNode) insnNode; - List stack = frameNode.stack; - if (stack != null) - { - for (int i = 0, l = stack.size(); i < l; i++) - { - final Object v = stack.get(i); - // replaces uninitialized object nodes with the actual type from the stack - if (v instanceof LabelNode) - { - AbstractInsnNode node = (AbstractInsnNode) v; - while (!(node instanceof TypeInsnNode && node.getOpcode() == NEW)) - { - node = node.getNext(); - } - stack.set(i, Type.getType(((TypeInsnNode) node).desc).getInternalName()); - } - } - } + frameNode.stack = replaceUninitializedFrameValues(uninitializedObjects, frameNode.stack); + frameNode.local = replaceUninitializedFrameValues(uninitializedObjects, frameNode.local); } else if (insnNode.getOpcode() == INVOKESPECIAL) { MethodInsnNode methodInsnNode = (MethodInsnNode) insnNode; if (methodInsnNode.name.equals("")) { - String oOwner = methodInsnNode.owner; - String oDesc = methodInsnNode.desc; - Type cType = Type.getObjectType(oOwner); - methodInsnNode.setOpcode(INVOKESTATIC); - methodInsnNode.name = "new$async$" + cType.getInternalName().replace('/', '_'); - methodInsnNode.owner = classNode.name; - Type[] oldArguments = Type.getArgumentTypes(methodInsnNode.desc); - - Frame frameBefore = frames[index]; - int targetStackIndex = frameBefore.getStackSize() - (1 + oldArguments.length); - final Value target = frameBefore.getStack(targetStackIndex); - int extraConsumed = 0; - // how many more to remove from the stack (normally exactly 1) - for (int i = targetStackIndex; --i >= 0 && target.equals(frameBefore.getStack(i)); ) - { - extraConsumed++; - } - - // test extraConsumed = 0 and extraConsumed = 2,3,4 - // done at #uninitializedInTheStackSingle and #uninitializedInTheStackMultipleExtraCopies + insnNode = replaceConstructorCall(methodNode, frames[index], uninitializedObjects, originalLocals, methodInsnNode); + } + } + } + // replace new calls + for (AbstractInsnNode insnNode = methodNode.instructions.getFirst(); insnNode != null; insnNode = insnNode.getNext()) + { + if (insnNode.getOpcode() == NEW && (uninitializedObjects.contains(insnNode))) + { + InsnNode newInsn = new InsnNode(ACONST_NULL); + methodNode.instructions.insertBefore(insnNode, newInsn); + methodNode.instructions.remove(insnNode); + insnNode = newInsn; + } + } + } - Type[] argumentTypes = new Type[oldArguments.length + 1 + extraConsumed]; - for (int i = 0; i <= extraConsumed; i++) - { - argumentTypes[i] = OBJECT_TYPE; - } - System.arraycopy(oldArguments, 0, argumentTypes, 1 + extraConsumed, oldArguments.length); - methodInsnNode.desc = Type.getMethodDescriptor(cType, argumentTypes); - final String key = methodInsnNode.name + methodInsnNode.desc; + private AbstractInsnNode replaceConstructorCall( + final MethodNode methodNode, + final Frame frame, final Set uninitializedObjects, final int originalLocals, + final MethodInsnNode methodInsnNode) + { + Type[] oldArguments = Type.getArgumentTypes(methodInsnNode.desc); + int targetStackIndex = frame.getStackSize() - (1 + oldArguments.length); + final FrameAnalyzer.ExtendedValue target = (FrameAnalyzer.ExtendedValue) frame.getStack(targetStackIndex); + if (uninitializedObjects != null && !uninitializedObjects.contains(target.insnNode)) + { + // only replaces the objects that need replacement + return methodInsnNode; + } - int stackSizeAfter = targetStackIndex; - // accounting for the extra pops - stackSizeAfter -= extraConsumed; + final InsnList instructions = methodNode.instructions; + // later, methodInsnNode is moved to the end of all inserted instructions + AbstractInsnNode currentInsn = methodInsnNode; - // must test with double in the locals: { Ljava/lang/Object; D Uninitialized } - // done at #uninitializedStoreWithWideVarsAndGaps - for (int j = 0; j < frameBefore.getLocals(); ) - { - // replaces all locals that used to reference the old value - BasicValue local = frameBefore.getLocal(j); - if (target.equals(local)) - { - methodNode.instructions.insert(insnNode, insnNode = new InsnNode(DUP)); - methodNode.instructions.insert(insnNode, insnNode = new VarInsnNode(ASTORE, j)); - } - j += local.getSize(); - } + // find the first reference to the target in the stack and saves everything after it. + int firstOccurrence = 0; + int[] stackToLocal = new int[frame.getStackSize()]; + Arrays.fill(stackToLocal, -1); + while (firstOccurrence < targetStackIndex && !target.equals(frame.getStack(firstOccurrence))) + { + firstOccurrence++; + } + // number of repetitions in the stack + int repetitions = 1; + for (int i = firstOccurrence + 1; i < frame.getStackSize() && target.equals(frame.getStack(i)); i++) + { + repetitions++; + } - // find first stack occurrence that needs to be replaced - int firstOccurrence = -1; - for (int j = 0; j < stackSizeAfter; j++) - { - // replaces all locals that used to reference the old value - BasicValue local = frameBefore.getStack(j); - if (target.equals(local)) - { - firstOccurrence = j; - break; - } - } - if (firstOccurrence >= 0) - { - // replaces it in the stack - // must test with double and long - int newMaxLocals = originalLocals; + // stores relevant stack values to new local variables + int newMaxLocals = 0; + int newObject = -1; + for (int iLocal = originalLocals, j = frame.getStackSize(); --j >= firstOccurrence; ) + { + BasicValue value = frame.getStack(j); + if (value.getType() == null) + { + // some uninitialized value, shouldn't happen, just in case + instructions.insert(currentInsn, currentInsn = new InsnNode(POP)); + instructions.insert(currentInsn, currentInsn = new InsnNode(ACONST_NULL)); + value = BasicValue.REFERENCE_VALUE; + } + if (!target.equals(value)) + { + stackToLocal[j] = iLocal; + // storing to a temporary local variable + instructions.insert(currentInsn, currentInsn = new VarInsnNode(value.getType().getOpcode(ISTORE), iLocal)); + iLocal += valueSize(value); + } + else + { + // position where we will put the new object if needed + if (j >= firstOccurrence + repetitions) + { + stackToLocal[j] = newObject != -1 ? newObject : (newObject = iLocal++); + } + instructions.insert(currentInsn, currentInsn = new InsnNode(POP)); + } + newMaxLocals = iLocal; + } + methodNode.maxLocals = Math.max(newMaxLocals, methodNode.maxLocals); - // stores the new object - methodNode.instructions.insert(insnNode, insnNode = new VarInsnNode(ASTORE, newMaxLocals)); - newMaxLocals++; + // creates the object + instructions.insert(currentInsn, currentInsn = new TypeInsnNode(NEW, target.getType().getInternalName())); - // stores everything (but the new refs) in the stack up to firstOccurrence - for (int j = stackSizeAfter; --j >= firstOccurrence; ) - { - BasicValue value = frameBefore.getStack(j); - if (!target.equals(value) && value.getType() != null) - { - methodNode.instructions.insert(insnNode, insnNode = new VarInsnNode(value.getType().getOpcode(ISTORE), newMaxLocals)); - newMaxLocals += value.getType().getSize(); - } - else - { - methodNode.instructions.insert(insnNode, insnNode = new InsnNode(POP)); - } - } - // restores the stack replacing the uninitialized refs - int iLocal = newMaxLocals; - for (int j = firstOccurrence; j < stackSizeAfter; j++) - { - BasicValue value = frameBefore.getStack(j); - if (target.equals(value)) - { - // replaces the old refs - methodNode.instructions.insert(insnNode, insnNode = new VarInsnNode(ALOAD, originalLocals)); - } - else - { - if (value.getType() != null) - { - iLocal -= value.getType().getSize(); - methodNode.instructions.insert(insnNode, insnNode = new VarInsnNode(value.getType().getOpcode(ILOAD), iLocal)); - } - else - { - methodNode.instructions.insert(insnNode, insnNode = new InsnNode(ACONST_NULL)); - } - } - } - methodNode.instructions.insert(insnNode, insnNode = new VarInsnNode(ALOAD, originalLocals)); - methodNode.maxLocals = Math.max(newMaxLocals, methodNode.maxLocals); - } + // stores the new object to all locals that should contain it, if any + for (int j = 0; j < frame.getLocals(); ) + { + // replaces all locals that used to reference the old value + BasicValue local = frame.getLocal(j); + if (target.equals(local)) + { + instructions.insert(currentInsn, currentInsn = new InsnNode(DUP)); + instructions.insert(currentInsn, currentInsn = new VarInsnNode(ASTORE, j)); + } + j += local.getSize(); + } - if (extraConsumed == 0) - { - methodNode.instructions.insert(insnNode, insnNode = new InsnNode(POP)); - } - else - { - for (int i = 1; i < extraConsumed; i++) - { - methodNode.instructions.insert(insnNode, insnNode = new InsnNode(DUP)); - } - } + if (firstOccurrence < targetStackIndex) + { + // duping instead of putting it to a local, just to look as regular java code. + for (int i = 1; i < repetitions; i++) + { + instructions.insert(currentInsn, currentInsn = new InsnNode(DUP)); + } + if (newObject != -1) + { + instructions.insert(currentInsn, currentInsn = new InsnNode(DUP)); + instructions.insert(currentInsn, currentInsn = new VarInsnNode(ASTORE, newObject)); + } + } - if (nameUseCount.get(key) == null) - { - nameUseCount.put(key, 1); - // thankfully the verifier doesn't check for the exceptions. - // nor the generic signature - final MethodVisitor mv = classNode.visitMethod(ACC_PRIVATE | ACC_STATIC, methodInsnNode.name, methodInsnNode.desc, null, null); - mv.visitCode(); - - mv.visitTypeInsn(NEW, oOwner); - mv.visitInsn(DUP); - int argSizes = 2; - // must test with long and double - for (int i = extraConsumed + 1; i < argumentTypes.length; i++) - { - mv.visitVarInsn((argumentTypes[i].getOpcode(ILOAD)), i); - argSizes += argumentTypes[i].getSize(); - } - mv.visitMethodInsn(INVOKESPECIAL, oOwner, "", oDesc, false); - mv.visitInsn(ARETURN); - mv.visitMaxs(argSizes, argSizes + 1); - mv.visitEnd(); - } - } + // restoring the the stack + for (int j = firstOccurrence + repetitions; j < frame.getStackSize(); j++) + { + final BasicValue value = frame.getStack(j); + if (value.getType() != null) + { + instructions.insert(currentInsn, currentInsn = new VarInsnNode(value.getType().getOpcode(ILOAD), stackToLocal[j])); + } + else + { + // uninitialized value + instructions.insert(currentInsn, currentInsn = new InsnNode(ACONST_NULL)); } } - // replace new calls - for (AbstractInsnNode insnNode = methodNode.instructions.getFirst(); insnNode != null; insnNode = insnNode.getNext()) + // move the constructor call to here + instructions.remove(methodInsnNode); + instructions.insert(currentInsn, currentInsn = methodInsnNode); + + // checks if there is stack reconstruction to do: + return currentInsn; + } + + private List replaceUninitializedFrameValues( + final Set uninitializedObjects, + final List list) + { + if (list == null) { - if (insnNode.getOpcode() == NEW) + return null; + } + final List newList = new ArrayList<>(list); + for (int i = 0, l = newList.size(); i < l; i++) + { + final Object v = newList.get(i); + // replaces uninitialized object nodes with the actual type from the newList + if (v instanceof LabelNode) { - InsnNode newInsn = new InsnNode(ACONST_NULL); - methodNode.instructions.insertBefore(insnNode, newInsn); - methodNode.instructions.remove(insnNode); - insnNode = newInsn; + AbstractInsnNode node = (AbstractInsnNode) v; + while (!(node instanceof TypeInsnNode && node.getOpcode() == NEW)) + { + node = node.getNext(); + } + if (uninitializedObjects.contains(node)) + { + newList.set(i, Type.getType(((TypeInsnNode) node).desc).getInternalName()); + } } } + return newList; } /** diff --git a/async/src/test/java/com/ea/orbit/async/instrumentation/ConstructorReplacementTest.java b/async/src/test/java/com/ea/orbit/async/instrumentation/ConstructorReplacementTest.java index 3d0111e5e..9508ae558 100644 --- a/async/src/test/java/com/ea/orbit/async/instrumentation/ConstructorReplacementTest.java +++ b/async/src/test/java/com/ea/orbit/async/instrumentation/ConstructorReplacementTest.java @@ -31,17 +31,59 @@ import com.ea.orbit.async.test.BaseTest; import org.junit.Test; +import org.objectweb.asm.tree.AbstractInsnNode; import org.objectweb.asm.tree.ClassNode; import org.objectweb.asm.tree.MethodNode; -import java.util.HashMap; +import java.util.HashSet; +import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.junit.Assert.assertEquals; import static org.objectweb.asm.Opcodes.*; public class ConstructorReplacementTest extends BaseTest { + @Test + @SuppressWarnings("unchecked") + public void regularConstructorCall() throws Exception + { + // check that the constructor replacement is able to replace interleaved elements in the stack + // obs.: the java compiler doesn't usually produce code like this + // regular java compiler will do: + // -- { new dup dup ... } + // this tests what happens if: + // -- { new dup push_1 swap ... pop } + MethodNode mv = new MethodNode(ACC_PUBLIC, "apply", "(Ljava/lang/Object;)Ljava/lang/Object;", null, new String[]{"java/lang/Exception"}); + mv.visitTypeInsn(NEW, "java/lang/Integer"); + mv.visitInsn(DUP); + mv.visitVarInsn(ALOAD, 1); + mv.visitTypeInsn(CHECKCAST, "java/lang/String"); + mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Integer", "", "(Ljava/lang/String;)V", false); + mv.visitInsn(ARETURN); + mv.visitMaxs(4, 2); + + // without replacement + { + final ClassNode cn = createClassNode(Function.class, null); + mv.accept(cn); + assertEquals(101, createClass(Function.class, cn).apply("101")); + } + + // with replacement + { + final ClassNode cn = createClassNode(Function.class, null); + new Transformer().replaceObjectInitialization(mv, + new FrameAnalyzer().analyze(cn.name, mv), findConstructors(mv)); + mv.accept(cn); + // DevDebug.debugSaveTrace(cn.name, cn); + assertEquals(101, createClass(Function.class, cn).apply("101")); + } + } + + @Test @SuppressWarnings("unchecked") public void withInterleavedCopies() throws Exception @@ -52,7 +94,7 @@ public void withInterleavedCopies() throws Exception // -- { new dup dup ... } // this tests what happens if: // -- { new dup push_1 swap ... pop } - MethodNode mv = new MethodNode(ACC_PUBLIC, "apply", "(Ljava/lang/Object;)Ljava/lang/Object;", null, new String[]{ "java/lang/Exception" }); + MethodNode mv = new MethodNode(ACC_PUBLIC, "apply", "(Ljava/lang/Object;)Ljava/lang/Object;", null, new String[]{"java/lang/Exception"}); mv.visitTypeInsn(NEW, "java/lang/Integer"); mv.visitInsn(DUP); // interleaving copies in the stack @@ -78,21 +120,28 @@ public void withInterleavedCopies() throws Exception // with replacement { final ClassNode cn = createClassNode(Function.class, null); - new Transformer().replaceObjectInitialization(cn, mv, - new HashMap<>(), new FrameAnalyzer().analyze(cn.name, mv)); + new Transformer().replaceObjectInitialization(mv, + new FrameAnalyzer().analyze(cn.name, mv), findConstructors(mv)); mv.accept(cn); // DevDebug.debugSaveTrace(cn.name, cn); assertEquals(101, createClass(Function.class, cn).apply("101")); } } + public Set findConstructors(final MethodNode mv) + { + return Stream.of(mv.instructions.toArray()) + .filter(i -> i.getOpcode() == NEW) + .collect(Collectors.toSet()); + } + @Test @SuppressWarnings("unchecked") public void withMultipleInterleavedCopies() throws Exception { // check that the constructor replacement is able to replace interleaved elements in the stack // obs.: the java compiler doesn't usually produce code like this - MethodNode mv = new MethodNode(ACC_PUBLIC, "apply", "(Ljava/lang/Object;)Ljava/lang/Object;", null, new String[]{ "java/lang/Exception" }); + MethodNode mv = new MethodNode(ACC_PUBLIC, "apply", "(Ljava/lang/Object;)Ljava/lang/Object;", null, new String[]{"java/lang/Exception"}); mv.visitTypeInsn(NEW, "java/lang/Integer"); mv.visitInsn(DUP); mv.visitVarInsn(ASTORE, 2); @@ -135,8 +184,8 @@ public void withMultipleInterleavedCopies() throws Exception // with replacement { final ClassNode cn = createClassNode(Function.class, null); - new Transformer().replaceObjectInitialization(cn, mv, - new HashMap<>(), new FrameAnalyzer().analyze(cn.name, mv)); + new Transformer().replaceObjectInitialization(mv, + new FrameAnalyzer().analyze(cn.name, mv), findConstructors(mv)); mv.accept(cn); // DevDebug.debugSaveTrace(cn.name, cn); assertEquals(101, createClass(Function.class, cn).apply("101"));