Skip to content

Commit 6ec36d3

Browse files
ferakoczwangweij
authored andcommitted
8373059: Test sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java should pass on Aarch64
Reviewed-by: weijun, vpaprotski
1 parent a99f340 commit 6ec36d3

File tree

2 files changed

+63
-29
lines changed

2 files changed

+63
-29
lines changed

src/java.base/share/classes/sun/security/provider/ML_DSA.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1555,7 +1555,7 @@ boolean vectorNormBound(int[][] vec, int bound) {
15551555
return res;
15561556
}
15571557

1558-
// precondition: -2^31 * MONT_Q <= a, b < 2^31, -2^31 < a * b < 2^31 * MONT_Q
1558+
// precondition: -2^31 <= a, b < 2^31, -2^31 * MONT_Q <= a * b < 2^31 * MONT_Q
15591559
// computes a * b * 2^-32 mod MONT_Q
15601560
// the result is greater than -MONT_Q and less than MONT_Q
15611561
// See e.g. Algorithm 3 in https://eprint.iacr.org/2018/039.pdf

test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java renamed to test/jdk/sun/security/provider/pqc/ML_DSA_Intrinsic_Test.java

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,17 @@
3838
*/
3939
/*
4040
* @test
41-
* @comment This test should be reenabled on aarch64
42-
* @requires os.simpleArch == "x64"
4341
* @library /test/lib
4442
* @key randomness
4543
* @modules java.base/sun.security.provider:+open
4644
* @run main ML_DSA_Intrinsic_Test
4745
*/
4846

49-
// To run manually: java --add-opens java.base/sun.security.provider=ALL-UNNAMED --add-exports java.base/sun.security.provider=ALL-UNNAMED
50-
// -XX:+UnlockDiagnosticVMOptions -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java
47+
// To run manually:
48+
// java --add-opens java.base/sun.security.provider=ALL-UNNAMED
49+
// --add-exports java.base/sun.security.provider=ALL-UNNAMED
50+
// -XX:+UnlockDiagnosticVMOptions -XX:+UseDilithiumIntrinsics
51+
// test/jdk/sun/security/provider/pqc/ML_DSA_Intrinsic_Test.java
5152

5253
public class ML_DSA_Intrinsic_Test {
5354
public static void main(String[] args) throws Throwable {
@@ -104,9 +105,10 @@ public static void main(String[] args) throws Throwable {
104105
m.setAccessible(true);
105106
MethodHandle inverseNttJava = lookup.unreflect(m);
106107

107-
// Hint: if test fails, you can hardcode the seed to make the test more reproducible
108108
Random rnd = new Random();
109109
long seed = rnd.nextLong();
110+
// Hint: if a test fails, it prints the seed, so you can hardcode
111+
// it here to reproduce the failure
110112
rnd.setSeed(seed);
111113
//Note: it might be useful to increase this number during development of new intrinsics
112114
final int repeat = 10000;
@@ -117,59 +119,80 @@ public static void main(String[] args) throws Throwable {
117119
int[] prod3 = new int[ML_DSA_N];
118120
int[] prod4 = new int[ML_DSA_N];
119121
for (int i = 0; i < repeat; i++) {
120-
// Hint: if test fails, you can hardcode the seed to make the test more reproducible:
121-
// rnd.setSeed(seed);
122-
testMult(prod1, prod2, coeffs1, coeffs2, mult, multJava, rnd, seed, i);
122+
testMult(prod1, prod2, coeffs1, coeffs2,
123+
mult, multJava, rnd, seed, i);
123124
testMultConst(prod1, prod2, multConst, multConstJava, rnd, seed, i);
124-
testDecompose(prod1, prod2, prod3, prod4, coeffs1, coeffs2, decompose, decomposeJava, rnd, seed, i);
125+
testDecompose(prod1, prod2, prod3, prod4, coeffs1, coeffs2,
126+
decompose, decomposeJava, rnd, seed, i);
125127
testAlmostNtt(coeffs1, coeffs2, almostNtt, almostNttJava, rnd, seed, i);
126128
testInverseNtt(coeffs1, coeffs2, inverseNtt, inverseNttJava, rnd, seed, i);
127129
}
128130
System.out.println("Fuzz Success");
129131
}
130132

131-
private static final int ML_DSA_N = 256;
132-
public static void testMult(int[] prod1, int[] prod2, int[] coeffs1, int[] coeffs2,
133+
public static void testMult(int[] prod1, int[] prod2,
134+
int[] coeffs1, int[] coeffs2,
133135
MethodHandle mult, MethodHandle multJava, Random rnd,
134136
long seed, int i) throws Throwable {
135137

136-
for (int j = 0; j<ML_DSA_N; j++) {
137-
coeffs1[j] = rnd.nextInt();
138-
coeffs2[j] = rnd.nextInt();
138+
// This method is always called with arrays whose elements are between
139+
// -ML_DSA_Q and ML_DSA_Q, so we only test for these here (although
140+
// both versions work fine with array element sizes that satisfy the
141+
// montMul() preconditions in sun.security.provider.ML_DSA.java
142+
for (int j = 0; j < ML_DSA_N; j++) {
143+
coeffs1[j] = rnd.nextInt(2 * ML_DSA_Q) - ML_DSA_Q;
144+
coeffs2[j] = rnd.nextInt(2 * ML_DSA_Q) - ML_DSA_Q;
139145
}
140146

141147
mult.invoke(prod1, coeffs1, coeffs2);
142148
multJava.invoke(prod2, coeffs1, coeffs2);
143149

144150
if (!Arrays.equals(prod1, prod2)) {
145-
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result mult mismatch: " + formatOf(prod1) + " != " + formatOf(prod2));
151+
// The Java version and the intrinsic version should not produce
152+
// the exact same result (although usually they do), it is enough
153+
// if the corresponding array elements are congruent modulo ML_DSA_Q
154+
boolean modQequal = true;
155+
for (int j = 0; j < ML_DSA_N; j++) {
156+
if (prod1[j] != prod2[j]) {
157+
modQequal &= (((prod1[j] - prod2[j]) % ML_DSA_Q) == 0);
158+
}
159+
}
160+
if (!modQequal) {
161+
throw new RuntimeException("[Seed " + seed + "@" + i
162+
+ "] Result mult mismatch: "
163+
+ formatOf(prod1) + "\n != " + formatOf(prod2));
164+
}
146165
}
147166
}
148167

149168
public static void testMultConst(int[] prod1, int[] prod2,
150169
MethodHandle multConst, MethodHandle multConstJava, Random rnd,
151170
long seed, int i) throws Throwable {
152171

153-
for (int j = 0; j<ML_DSA_N; j++) {
172+
for (int j = 0; j < ML_DSA_N; j++) {
154173
prod1[j] = prod2[j] = rnd.nextInt();
155174
}
156-
// Per Algorithm 3 in https://eprint.iacr.org/2018/039.pdf, one of the inputs is bound, which prevents overflows
157-
int dilithium_q = 8380417;
158-
int c = rnd.nextInt(dilithium_q);
175+
176+
// Per Algorithm 3 in https://eprint.iacr.org/2018/039.pdf,
177+
// one of the inputs is bound, which prevents overflows
178+
int c = rnd.nextInt(ML_DSA_Q);
159179

160180
multConst.invoke(prod1, c);
161181
multConstJava.invoke(prod2, c);
162182

163183
if (!Arrays.equals(prod1, prod2)) {
164-
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result multConst mismatch: " + formatOf(prod1) + " != " + formatOf(prod2));
184+
throw new RuntimeException("[Seed " + seed + "@" + i
185+
+ "] Result multConst mismatch: "
186+
+ formatOf(prod1) + " != " + formatOf(prod2));
165187
}
166188
}
167189

168-
public static void testDecompose(int[] low1, int[] high1, int[] low2, int[] high2, int[] coeffs1, int[] coeffs2,
190+
public static void testDecompose(int[] low1, int[] high1, int[] low2,
191+
int[] high2, int[] coeffs1, int[] coeffs2,
169192
MethodHandle decompose, MethodHandle decomposeJava, Random rnd,
170193
long seed, int i) throws Throwable {
171194

172-
for (int j = 0; j<ML_DSA_N; j++) {
195+
for (int j = 0; j < ML_DSA_N; j++) {
173196
coeffs1[j] = coeffs2[j] = rnd.nextInt();
174197
}
175198
int gamma2 = 95232;
@@ -182,41 +205,49 @@ public static void testDecompose(int[] low1, int[] high1, int[] low2, int[] high
182205
decomposeJava.invoke(coeffs2, low2, high2, 2 * gamma2, multiplier);
183206

184207
if (!Arrays.equals(low1, low2)) {
185-
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result low mismatch: " + formatOf(low1) + " != " + formatOf(low2));
208+
throw new RuntimeException("[Seed " + seed + "@" + i
209+
+ "] Result low mismatch: "
210+
+ formatOf(low1) + " != " + formatOf(low2));
186211
}
187212

188213
if (!Arrays.equals(high1, high2)) {
189-
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result high mismatch: " + formatOf(high1) + " != " + formatOf(high2));
214+
throw new RuntimeException("[Seed " + seed + "@" + i
215+
+ "] Result high mismatch: "
216+
+ formatOf(high1) + " != " + formatOf(high2));
190217
}
191218
}
192219

193220
public static void testAlmostNtt(int[] coeffs1, int[] coeffs2,
194221
MethodHandle almostNtt, MethodHandle almostNttJava, Random rnd,
195222
long seed, int i) throws Throwable {
196-
for (int j = 0; j<ML_DSA_N; j++) {
223+
for (int j = 0; j < ML_DSA_N; j++) {
197224
coeffs1[j] = coeffs2[j] = rnd.nextInt();
198225
}
199226

200227
almostNtt.invoke(coeffs1, MONT_ZETAS_FOR_VECTOR_NTT);
201228
almostNttJava.invoke(coeffs2);
202229

203230
if (!Arrays.equals(coeffs1, coeffs2)) {
204-
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result AlmostNtt mismatch: " + formatOf(coeffs1) + " != " + formatOf(coeffs2));
231+
throw new RuntimeException("[Seed " + seed + "@" + i
232+
+"] Result AlmostNtt mismatch: "
233+
+ formatOf(coeffs1) + " != " + formatOf(coeffs2));
205234
}
206235
}
207236

208237
public static void testInverseNtt(int[] coeffs1, int[] coeffs2,
209238
MethodHandle inverseNtt, MethodHandle inverseNttJava, Random rnd,
210239
long seed, int i) throws Throwable {
211-
for (int j = 0; j<ML_DSA_N; j++) {
240+
for (int j = 0; j < ML_DSA_N; j++) {
212241
coeffs1[j] = coeffs2[j] = rnd.nextInt();
213242
}
214243

215244
inverseNtt.invoke(coeffs1, MONT_ZETAS_FOR_VECTOR_INVERSE_NTT);
216245
inverseNttJava.invoke(coeffs2);
217246

218247
if (!Arrays.equals(coeffs1, coeffs2)) {
219-
throw new RuntimeException("[Seed "+seed+"@"+i+"] Result InverseNtt mismatch: " + formatOf(coeffs1) + " != " + formatOf(coeffs2));
248+
throw new RuntimeException("[Seed " + seed+ "@" + i
249+
+"] Result InverseNtt mismatch: "
250+
+ formatOf(coeffs1) + " != " + formatOf(coeffs2));
220251
}
221252
}
222253

@@ -230,6 +261,9 @@ private static CharSequence formatOf(int[] arr) {
230261
}
231262

232263
// Copied constants from sun.security.provider.ML_DSA
264+
private static final int ML_DSA_N = 256;
265+
private static final int ML_DSA_Q = 8380417;
266+
233267
private static final int[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT = new int[]{
234268
-1976782, 846154, -1400424, -3937738, 1362209, 48306, -3919660, 554416,
235269
3545687, -1612842, 976891, -183443, 2286327, 420899, 2235985, 2939036,

0 commit comments

Comments
 (0)