Skip to content

Commit

Permalink
8297379: Enable the ByteBuffer path of Poly1305 optimizations
Browse files Browse the repository at this point in the history
Reviewed-by: sviswanathan, ascarpino, jnimeh
  • Loading branch information
vpaprotsk authored and Sandhya Viswanathan committed Dec 6, 2022
1 parent 1e46832 commit 203251f
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 19 deletions.
29 changes: 23 additions & 6 deletions src/java.base/share/classes/com/sun/crypto/provider/Poly1305.java
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ void engineUpdate(ByteBuffer buf) {
BLOCK_LENGTH - blockOffset);

if (bytesToWrite >= BLOCK_LENGTH) {
// If bytes to write == BLOCK_LENGTH, then we have no
// left-over data from previous updates and we can create
// the IntegerModuloP directly from the input buffer.
processBlock(buf, bytesToWrite);
// Have at least one full block in the buf, process all full blocks
int blockMultipleLength = remaining & (~(BLOCK_LENGTH-1));
processMultipleBlocks(buf, blockMultipleLength);
remaining -= blockMultipleLength;
} else {
// We have some left-over data from previous updates, so
// copy that into the holding block until we get a full block.
Expand All @@ -138,9 +138,8 @@ void engineUpdate(ByteBuffer buf) {
processBlock(block, 0, BLOCK_LENGTH);
blockOffset = 0;
}
remaining -= bytesToWrite;
}

remaining -= bytesToWrite;
}
}

Expand Down Expand Up @@ -255,6 +254,24 @@ private void processMultipleBlocks(byte[] input, int offset, int length, long[]
}
}

private void processMultipleBlocks(ByteBuffer buf, int blockMultipleLength) {
if (buf.hasArray()) {
byte[] input = buf.array();
int offset = buf.arrayOffset() + buf.position();
long[] aLimbs = a.getLimbs();
long[] rLimbs = r.getLimbs();

processMultipleBlocksCheck(input, offset, blockMultipleLength, aLimbs, rLimbs);
processMultipleBlocks(input, offset, blockMultipleLength, aLimbs, rLimbs);
buf.position(offset + blockMultipleLength);
} else {
while (blockMultipleLength >= BLOCK_LENGTH) {
processBlock(buf, BLOCK_LENGTH);
blockMultipleLength -= BLOCK_LENGTH;
}
}
}

private static void processMultipleBlocksCheck(byte[] input, int offset, int length, long[] aLimbs, long[] rLimbs) {
Objects.checkFromIndexSize(offset, length, input.length);
final int numLimbs = 5; // Intrinsic expects exactly 5 limbs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Random;

import javax.crypto.spec.SecretKeySpec;

Expand All @@ -36,22 +37,23 @@
public class Poly1305IntrinsicFuzzTest {
public static void main(String[] args) throws Exception {
//Note: it might be useful to increase this number during development of new Poly1305 intrinsics
final int repeat = 100;
final int repeat = 1000;
for (int i = 0; i < repeat; i++) {
run();
}
System.out.println("Fuzz Success");
}

public static void run() throws Exception {
java.util.Random rnd = new java.util.Random();
Random rnd = new Random();
long seed = rnd.nextLong();
rnd.setSeed(seed);

byte[] key = new byte[32];
rnd.nextBytes(key);
int msgLen = rnd.nextInt(128, 4096); // x86_64 intrinsic requires 256 bytes minimum
byte[] message = new byte[msgLen];
rnd.nextBytes(message);

Poly1305 authenticator = new Poly1305();
Poly1305 authenticatorSlow = new Poly1305();
Expand All @@ -62,20 +64,22 @@ public static void run() throws Exception {
authenticator.engineInit(new SecretKeySpec(key, 0, 32, "Poly1305"), null);
authenticatorSlow.engineInit(new SecretKeySpec(key, 0, 32, "Poly1305"), null);

if (rnd.nextBoolean()) {
if (rnd.nextBoolean() && message.length > 16) {
// Prime just the buffer and/or accumulator (buffer can keep at most 16 bytes from previous engineUpdate)
int initDataLen = rnd.nextInt(8, 24);
authenticator.engineUpdate(message, 0, initDataLen);
slowUpdate(authenticatorSlow, message, 0, initDataLen);
int initDataLen = rnd.nextInt(1, 16);
int initDataOffset = rnd.nextInt(0, message.length - initDataLen);
fastUpdate(authenticator, rnd, message, initDataOffset, initDataLen);
slowUpdate(authenticatorSlow, message, initDataOffset, initDataLen);
}

if (rnd.nextBoolean()) {
// Multiple calls to engineUpdate
authenticator.engineUpdate(message, 0, message.length);
slowUpdate(authenticatorSlow, message, 0, message.length);
int initDataOffset = rnd.nextInt(0, message.length);
fastUpdate(authenticator, rnd, message, initDataOffset, message.length - initDataOffset);
slowUpdate(authenticatorSlow, message, initDataOffset, message.length - initDataOffset);
}

authenticator.engineUpdate(message, 0, message.length);
fastUpdate(authenticator, rnd, message, 0, message.length);
slowUpdate(authenticatorSlow, message, 0, message.length);

byte[] tag = authenticator.engineDoFinal();
Expand All @@ -87,9 +91,34 @@ public static void run() throws Exception {
}

static void slowUpdate(Poly1305 authenticator, byte[] message, int offset, int len) {
len = Math.min(message.length, offset + len);
for (int i = offset; i < len; i++) {
for (int i = offset; i < offset + len; i++) {
authenticator.engineUpdate(message[i]);
}
}

static void fastUpdate(Poly1305 authenticator, Random rnd, byte[] message, int offset, int len) {
ByteBuffer buf;
switch(rnd.nextInt(4)) {
case 0: // byte[]
authenticator.engineUpdate(message, offset, len);
break;
case 1: // ByteArray with backing array
buf = ByteBuffer.wrap(message, offset, len)
.order(java.nio.ByteOrder.LITTLE_ENDIAN);
authenticator.engineUpdate(buf);
break;
case 2: // ByteArray with backing array (non-zero position)
buf = ByteBuffer.wrap(message, 0, len+offset)
.order(java.nio.ByteOrder.LITTLE_ENDIAN)
.position(offset);
authenticator.engineUpdate(buf);
break;
case 3: // ByteArray without backing array (wont be sent to intrinsic)
buf = ByteBuffer.wrap(message, offset, len)
.asReadOnlyBuffer()
.order(java.nio.ByteOrder.LITTLE_ENDIAN);
authenticator.engineUpdate(buf);
break;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.annotations.Measurement;
import java.nio.ByteBuffer;

@Measurement(iterations = 3, time = 10)
@Warmup(iterations = 3, time = 10)
Expand All @@ -49,7 +50,7 @@ public class Poly1305DigestBench extends CryptoBase {

private byte[][] data;
int index = 0;
private static MethodHandle polyEngineInit, polyEngineUpdate, polyEngineFinal;
private static MethodHandle polyEngineInit, polyEngineUpdate, polyEngineUpdateBuf, polyEngineFinal;
private static Object polyObj;

static {
Expand All @@ -68,6 +69,10 @@ public class Poly1305DigestBench extends CryptoBase {
m.setAccessible(true);
polyEngineUpdate = lookup.unreflect(m);

m = polyClazz.getDeclaredMethod("engineUpdate", ByteBuffer.class);
m.setAccessible(true);
polyEngineUpdateBuf = lookup.unreflect(m);

m = polyClazz.getDeclaredMethod("engineDoFinal");
m.setAccessible(true);
polyEngineFinal = lookup.unreflect(m);
Expand All @@ -83,7 +88,7 @@ public void setup() {
}

@Benchmark
public byte[] digest() {
public byte[] digestBytes() {
try {
byte[] d = data[index];
index = (index +1) % SET_SIZE;
Expand All @@ -94,4 +99,17 @@ public byte[] digest() {
throw new RuntimeException(ex);
}
}

@Benchmark
public byte[] digestBuffer() {
try {
byte[] d = data[index];
index = (index +1) % SET_SIZE;
polyEngineInit.invoke(polyObj, new SecretKeySpec(d, 0, 32, "Poly1305"), null);
polyEngineUpdateBuf.invoke(polyObj, ByteBuffer.wrap(d, 0, d.length));
return (byte[])polyEngineFinal.invoke(polyObj);
} catch (Throwable ex) {
throw new RuntimeException(ex);
}
}
}

1 comment on commit 203251f

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.