3838 */
3939/*
4040 * @test
41- * @comment This test should be reenabled on aarch64
42- * @requires os.simpleArch == "x64"
4341 * @library /test/lib
4442 * @key randomness
4543 * @modules java.base/sun.security.provider:+open
4644 * @run main ML_DSA_Intrinsic_Test
4745 */
4846
49- // To run manually: java --add-opens java.base/sun.security.provider=ALL-UNNAMED --add-exports java.base/sun.security.provider=ALL-UNNAMED
50- // -XX:+UnlockDiagnosticVMOptions -XX:+UseDilithiumIntrinsics test/jdk/sun/security/provider/acvp/ML_DSA_Intrinsic_Test.java
47+ // To run manually:
48+ // java --add-opens java.base/sun.security.provider=ALL-UNNAMED
49+ // --add-exports java.base/sun.security.provider=ALL-UNNAMED
50+ // -XX:+UnlockDiagnosticVMOptions -XX:+UseDilithiumIntrinsics
51+ // test/jdk/sun/security/provider/pqc/ML_DSA_Intrinsic_Test.java
5152
5253public class ML_DSA_Intrinsic_Test {
5354 public static void main (String [] args ) throws Throwable {
@@ -104,9 +105,10 @@ public static void main(String[] args) throws Throwable {
104105 m .setAccessible (true );
105106 MethodHandle inverseNttJava = lookup .unreflect (m );
106107
107- // Hint: if test fails, you can hardcode the seed to make the test more reproducible
108108 Random rnd = new Random ();
109109 long seed = rnd .nextLong ();
110+ // Hint: if a test fails, it prints the seed, so you can hardcode
111+ // it here to reproduce the failure
110112 rnd .setSeed (seed );
111113 //Note: it might be useful to increase this number during development of new intrinsics
112114 final int repeat = 10000 ;
@@ -117,59 +119,80 @@ public static void main(String[] args) throws Throwable {
117119 int [] prod3 = new int [ML_DSA_N ];
118120 int [] prod4 = new int [ML_DSA_N ];
119121 for (int i = 0 ; i < repeat ; i ++) {
120- // Hint: if test fails, you can hardcode the seed to make the test more reproducible:
121- // rnd.setSeed(seed);
122- testMult (prod1 , prod2 , coeffs1 , coeffs2 , mult , multJava , rnd , seed , i );
122+ testMult (prod1 , prod2 , coeffs1 , coeffs2 ,
123+ mult , multJava , rnd , seed , i );
123124 testMultConst (prod1 , prod2 , multConst , multConstJava , rnd , seed , i );
124- testDecompose (prod1 , prod2 , prod3 , prod4 , coeffs1 , coeffs2 , decompose , decomposeJava , rnd , seed , i );
125+ testDecompose (prod1 , prod2 , prod3 , prod4 , coeffs1 , coeffs2 ,
126+ decompose , decomposeJava , rnd , seed , i );
125127 testAlmostNtt (coeffs1 , coeffs2 , almostNtt , almostNttJava , rnd , seed , i );
126128 testInverseNtt (coeffs1 , coeffs2 , inverseNtt , inverseNttJava , rnd , seed , i );
127129 }
128130 System .out .println ("Fuzz Success" );
129131 }
130132
131- private static final int ML_DSA_N = 256 ;
132- public static void testMult ( int [] prod1 , int [] prod2 , int [] coeffs1 , int [] coeffs2 ,
133+ public static void testMult ( int [] prod1 , int [] prod2 ,
134+ int [] coeffs1 , int [] coeffs2 ,
133135 MethodHandle mult , MethodHandle multJava , Random rnd ,
134136 long seed , int i ) throws Throwable {
135137
136- for (int j = 0 ; j <ML_DSA_N ; j ++) {
137- coeffs1 [j ] = rnd .nextInt ();
138- coeffs2 [j ] = rnd .nextInt ();
138+ // This method is always called with arrays whose elements are between
139+ // -ML_DSA_Q and ML_DSA_Q, so we only test for these here (although
140+ // both versions work fine with array element sizes that satisfy the
141+ // montMul() preconditions in sun.security.provider.ML_DSA.java
142+ for (int j = 0 ; j < ML_DSA_N ; j ++) {
143+ coeffs1 [j ] = rnd .nextInt (2 * ML_DSA_Q ) - ML_DSA_Q ;
144+ coeffs2 [j ] = rnd .nextInt (2 * ML_DSA_Q ) - ML_DSA_Q ;
139145 }
140146
141147 mult .invoke (prod1 , coeffs1 , coeffs2 );
142148 multJava .invoke (prod2 , coeffs1 , coeffs2 );
143149
144150 if (!Arrays .equals (prod1 , prod2 )) {
145- throw new RuntimeException ("[Seed " +seed +"@" +i +"] Result mult mismatch: " + formatOf (prod1 ) + " != " + formatOf (prod2 ));
151+ // The Java version and the intrinsic version should not produce
152+ // the exact same result (although usually they do), it is enough
153+ // if the corresponding array elements are congruent modulo ML_DSA_Q
154+ boolean modQequal = true ;
155+ for (int j = 0 ; j < ML_DSA_N ; j ++) {
156+ if (prod1 [j ] != prod2 [j ]) {
157+ modQequal &= (((prod1 [j ] - prod2 [j ]) % ML_DSA_Q ) == 0 );
158+ }
159+ }
160+ if (!modQequal ) {
161+ throw new RuntimeException ("[Seed " + seed + "@" + i
162+ + "] Result mult mismatch: "
163+ + formatOf (prod1 ) + "\n != " + formatOf (prod2 ));
164+ }
146165 }
147166 }
148167
149168 public static void testMultConst (int [] prod1 , int [] prod2 ,
150169 MethodHandle multConst , MethodHandle multConstJava , Random rnd ,
151170 long seed , int i ) throws Throwable {
152171
153- for (int j = 0 ; j < ML_DSA_N ; j ++) {
172+ for (int j = 0 ; j < ML_DSA_N ; j ++) {
154173 prod1 [j ] = prod2 [j ] = rnd .nextInt ();
155174 }
156- // Per Algorithm 3 in https://eprint.iacr.org/2018/039.pdf, one of the inputs is bound, which prevents overflows
157- int dilithium_q = 8380417 ;
158- int c = rnd .nextInt (dilithium_q );
175+
176+ // Per Algorithm 3 in https://eprint.iacr.org/2018/039.pdf,
177+ // one of the inputs is bound, which prevents overflows
178+ int c = rnd .nextInt (ML_DSA_Q );
159179
160180 multConst .invoke (prod1 , c );
161181 multConstJava .invoke (prod2 , c );
162182
163183 if (!Arrays .equals (prod1 , prod2 )) {
164- throw new RuntimeException ("[Seed " +seed +"@" +i +"] Result multConst mismatch: " + formatOf (prod1 ) + " != " + formatOf (prod2 ));
184+ throw new RuntimeException ("[Seed " + seed + "@" + i
185+ + "] Result multConst mismatch: "
186+ + formatOf (prod1 ) + " != " + formatOf (prod2 ));
165187 }
166188 }
167189
168- public static void testDecompose (int [] low1 , int [] high1 , int [] low2 , int [] high2 , int [] coeffs1 , int [] coeffs2 ,
190+ public static void testDecompose (int [] low1 , int [] high1 , int [] low2 ,
191+ int [] high2 , int [] coeffs1 , int [] coeffs2 ,
169192 MethodHandle decompose , MethodHandle decomposeJava , Random rnd ,
170193 long seed , int i ) throws Throwable {
171194
172- for (int j = 0 ; j < ML_DSA_N ; j ++) {
195+ for (int j = 0 ; j < ML_DSA_N ; j ++) {
173196 coeffs1 [j ] = coeffs2 [j ] = rnd .nextInt ();
174197 }
175198 int gamma2 = 95232 ;
@@ -182,41 +205,49 @@ public static void testDecompose(int[] low1, int[] high1, int[] low2, int[] high
182205 decomposeJava .invoke (coeffs2 , low2 , high2 , 2 * gamma2 , multiplier );
183206
184207 if (!Arrays .equals (low1 , low2 )) {
185- throw new RuntimeException ("[Seed " +seed +"@" +i +"] Result low mismatch: " + formatOf (low1 ) + " != " + formatOf (low2 ));
208+ throw new RuntimeException ("[Seed " + seed + "@" + i
209+ + "] Result low mismatch: "
210+ + formatOf (low1 ) + " != " + formatOf (low2 ));
186211 }
187212
188213 if (!Arrays .equals (high1 , high2 )) {
189- throw new RuntimeException ("[Seed " +seed +"@" +i +"] Result high mismatch: " + formatOf (high1 ) + " != " + formatOf (high2 ));
214+ throw new RuntimeException ("[Seed " + seed + "@" + i
215+ + "] Result high mismatch: "
216+ + formatOf (high1 ) + " != " + formatOf (high2 ));
190217 }
191218 }
192219
193220 public static void testAlmostNtt (int [] coeffs1 , int [] coeffs2 ,
194221 MethodHandle almostNtt , MethodHandle almostNttJava , Random rnd ,
195222 long seed , int i ) throws Throwable {
196- for (int j = 0 ; j < ML_DSA_N ; j ++) {
223+ for (int j = 0 ; j < ML_DSA_N ; j ++) {
197224 coeffs1 [j ] = coeffs2 [j ] = rnd .nextInt ();
198225 }
199226
200227 almostNtt .invoke (coeffs1 , MONT_ZETAS_FOR_VECTOR_NTT );
201228 almostNttJava .invoke (coeffs2 );
202229
203230 if (!Arrays .equals (coeffs1 , coeffs2 )) {
204- throw new RuntimeException ("[Seed " +seed +"@" +i +"] Result AlmostNtt mismatch: " + formatOf (coeffs1 ) + " != " + formatOf (coeffs2 ));
231+ throw new RuntimeException ("[Seed " + seed + "@" + i
232+ +"] Result AlmostNtt mismatch: "
233+ + formatOf (coeffs1 ) + " != " + formatOf (coeffs2 ));
205234 }
206235 }
207236
208237 public static void testInverseNtt (int [] coeffs1 , int [] coeffs2 ,
209238 MethodHandle inverseNtt , MethodHandle inverseNttJava , Random rnd ,
210239 long seed , int i ) throws Throwable {
211- for (int j = 0 ; j < ML_DSA_N ; j ++) {
240+ for (int j = 0 ; j < ML_DSA_N ; j ++) {
212241 coeffs1 [j ] = coeffs2 [j ] = rnd .nextInt ();
213242 }
214243
215244 inverseNtt .invoke (coeffs1 , MONT_ZETAS_FOR_VECTOR_INVERSE_NTT );
216245 inverseNttJava .invoke (coeffs2 );
217246
218247 if (!Arrays .equals (coeffs1 , coeffs2 )) {
219- throw new RuntimeException ("[Seed " +seed +"@" +i +"] Result InverseNtt mismatch: " + formatOf (coeffs1 ) + " != " + formatOf (coeffs2 ));
248+ throw new RuntimeException ("[Seed " + seed + "@" + i
249+ +"] Result InverseNtt mismatch: "
250+ + formatOf (coeffs1 ) + " != " + formatOf (coeffs2 ));
220251 }
221252 }
222253
@@ -230,6 +261,9 @@ private static CharSequence formatOf(int[] arr) {
230261 }
231262
232263 // Copied constants from sun.security.provider.ML_DSA
264+ private static final int ML_DSA_N = 256 ;
265+ private static final int ML_DSA_Q = 8380417 ;
266+
233267 private static final int [] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT = new int []{
234268 -1976782 , 846154 , -1400424 , -3937738 , 1362209 , 48306 , -3919660 , 554416 ,
235269 3545687 , -1612842 , 976891 , -183443 , 2286327 , 420899 , 2235985 , 2939036 ,
0 commit comments