Skip to content
108 changes: 76 additions & 32 deletions src/java.base/share/classes/sun/security/provider/ML_DSA.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import jdk.internal.vm.annotation.IntrinsicCandidate;
import sun.security.provider.SHA3.SHAKE128;
import sun.security.provider.SHA3.SHAKE256;
import sun.security.provider.SHA3Parallel.Shake128Parallel;

import java.security.InvalidAlgorithmParameterException;
import java.security.MessageDigest;
import java.security.InvalidKeyException;
import java.security.SignatureException;
Expand All @@ -45,7 +47,7 @@ public class ML_DSA {
private static final int ML_DSA_Q = 8380417;
private static final int ML_DSA_N = 256;
private static final int SHAKE256_BLOCK_SIZE = 136; // the block length for SHAKE256

private static final int SHAKE128_BLOCK_SIZE = 168; // the block length for SHAKE128
private final int A_SEED_LEN = 32;
private final int S1S2_SEED_LEN = 64;
private final int K_LEN = 32;
Expand Down Expand Up @@ -1090,43 +1092,85 @@ private void sampleInBall(int[] c, byte[] rho) {
}
}

int[][][] generateA(byte[] seed) {
int blockSize = 168; // the size of one block of SHAKE128 output
var xof = new SHAKE128(0);
byte[] xofSeed = new byte[A_SEED_LEN + 2];
System.arraycopy(seed, 0, xofSeed, 0, A_SEED_LEN);
private int[][][] generateA(byte[] seed) {
int[][][] a = new int[mlDsa_k][mlDsa_l][];

for (int i = 0; i < mlDsa_k; i++) {
for (int j = 0; j < mlDsa_l; j++) {
xofSeed[A_SEED_LEN] = (byte) j;
xofSeed[A_SEED_LEN + 1] = (byte) i;
xof.reset();
xof.update(xofSeed);

byte[] rawAij = new byte[blockSize];
int[] aij = new int[ML_DSA_N];
int ofs = 0;
int rawOfs = blockSize;
int tmp;
while (ofs < ML_DSA_N) {
if (rawOfs == blockSize) {
// works because 3 divides blockSize (=168)
xof.squeeze(rawAij, 0, blockSize);
rawOfs = 0;
}
tmp = (rawAij[rawOfs] & 0xFF) +
((rawAij[rawOfs + 1] & 0xFF) << 8) +
((rawAij[rawOfs + 2] & 0x7F) << 16);
rawOfs += 3;
if (tmp < ML_DSA_Q) {
aij[ofs] = tmp;
ofs++;
int nrPar = 2;
int rhoLen = seed.length;
byte[] seedBuf = new byte[SHAKE128_BLOCK_SIZE];
System.arraycopy(seed, 0, seedBuf, 0, seed.length);
seedBuf[rhoLen + 2] = 0x1F;
seedBuf[SHAKE128_BLOCK_SIZE - 1] = (byte)0x80;
byte[][] xofBufArr = new byte[nrPar][SHAKE128_BLOCK_SIZE];
int[] iIndex = new int[nrPar];
int[] jIndex = new int[nrPar];

int[] parsedBuf = new int[SHAKE128_BLOCK_SIZE / 3];

int parInd = 0;
boolean allDone;
int[] ofs = new int[nrPar];
Arrays.fill(ofs, 0);
int[][] aij = new int[nrPar][];
try {
Shake128Parallel parXof = new Shake128Parallel(xofBufArr);

for (int i = 0; i < mlDsa_k; i++) {
for (int j = 0; j < mlDsa_l; j++) {
xofBufArr[parInd] = seedBuf.clone();
xofBufArr[parInd][rhoLen] = (byte) j;
xofBufArr[parInd][rhoLen + 1] = (byte) i;
iIndex[parInd] = i;
jIndex[parInd] = j;
ofs[parInd] = 0;
aij[parInd] = new int[ML_DSA_N];
parInd++;

if ((parInd == nrPar) ||
((i == mlDsa_k - 1) && (j == mlDsa_l - 1))) {
parXof.reset(xofBufArr);

allDone = false;
while (!allDone) {
allDone = true;
parXof.squeezeBlock();
for (int k = 0; k < parInd; k++) {
int parsedOfs = 0;
int tmp;
if (ofs[k] < ML_DSA_N) {
for (int l = 0; l < SHAKE128_BLOCK_SIZE; l += 3) {
byte[] rawBuf = xofBufArr[k];
parsedBuf[l / 3] = (rawBuf[l] & 0xFF) +
((rawBuf[l + 1] & 0xFF) << 8) +
((rawBuf[l + 2] & 0x7F) << 16);
}
}
while ((ofs[k] < ML_DSA_N) &&
(parsedOfs < SHAKE128_BLOCK_SIZE / 3)) {
tmp = parsedBuf[parsedOfs++];
if (tmp < ML_DSA_Q) {
aij[k][ofs[k]] = tmp;
ofs[k]++;
}
}
if (ofs[k] < ML_DSA_N) {
allDone = false;
}
}
}

for (int k = 0; k < parInd; k++) {
a[iIndex[k]][jIndex[k]] = aij[k];
}
parInd = 0;
}
}
a[i][j] = aij;
}
} catch (InvalidAlgorithmParameterException e) {
// This should never happen since xofBufArr is of the correct size
throw new RuntimeException("Internal error.");
}

return a;
}

Expand Down