diff --git a/src/jmh/java/cafe/cryptography/curve25519/ScalarBench.java b/src/jmh/java/cafe/cryptography/curve25519/ScalarBench.java
index 1cc4448..427fa36 100644
--- a/src/jmh/java/cafe/cryptography/curve25519/ScalarBench.java
+++ b/src/jmh/java/cafe/cryptography/curve25519/ScalarBench.java
@@ -58,6 +58,16 @@ public Scalar multiply() {
         return this.a.multiply(this.b);
     }
 
+    @Benchmark
+    public Scalar invert() {
+        return this.a.invert();
+    }
+
+    @Benchmark
+    public Scalar divide() {
+        return this.a.divide(this.b);
+    }
+
     @Benchmark
     public Scalar multiplyAndAddManual() {
         return this.a.multiply(this.b).add(this.c);
diff --git a/src/main/java/cafe/cryptography/curve25519/Constants.java b/src/main/java/cafe/cryptography/curve25519/Constants.java
index c31bb66..ce0e5f7 100644
--- a/src/main/java/cafe/cryptography/curve25519/Constants.java
+++ b/src/main/java/cafe/cryptography/curve25519/Constants.java
@@ -6,6 +6,8 @@
 
 package cafe.cryptography.curve25519;
 
+import java.math.BigInteger;
+
 /**
  * Various constants and useful parameters.
  */
@@ -154,4 +156,13 @@ public final class Constants {
      */
     public static final RistrettoGeneratorTable RISTRETTO_GENERATOR_TABLE = new RistrettoGeneratorTable(
             RISTRETTO_GENERATOR);
+
+    /**
+     * The exponent $\ell - 2$, where
+     * $\ell = 2^{252} + 27742317777372353535851937790883648493$ is the order
+     * of the Ed25519 basepoint.
+     *
+     * Used to compute the modular inverse of a Scalar as $a^{\ell - 2} \bmod \ell$.
+     */
+    public static final BigInteger ModMinus2 = new BigInteger("7237005577332262213973186563042994240857116359379907606001950938285454250987");
 }
diff --git a/src/main/java/cafe/cryptography/curve25519/Scalar.java b/src/main/java/cafe/cryptography/curve25519/Scalar.java
index a34a644..6a06f16 100644
--- a/src/main/java/cafe/cryptography/curve25519/Scalar.java
+++ b/src/main/java/cafe/cryptography/curve25519/Scalar.java
@@ -832,6 +832,55 @@ public Scalar multiplyAndAdd(Scalar b, Scalar c) {
         return new Scalar(result);
     }
 
+    /**
+     * Compute $a^2 \bmod \ell$.
+     *
+     * @return $a^2 \bmod \ell$
+     */
+    public Scalar square() {
+        // Delegate to the general multiplication: a dedicated squaring
+        // routine only pays off if it exploits the symmetry of the limb
+        // products, and a second hand-unrolled reduction is costly to audit.
+        return this.multiply(this);
+    }
+
+    /**
+     * Compute the multiplicative inverse $a^{-1} \bmod \ell$.
+     *
+     * The inverse is computed as $a^{\ell - 2} \bmod \ell$, using
+     * left-to-right square-and-multiply over the bits of
+     * {@link Constants#ModMinus2}.
+     *
+     * This Scalar should be non-zero. Zero has no inverse; for a zero input
+     * the accumulator becomes and stays zero, so zero is returned.
+     *
+     * @return $a^{-1} \bmod \ell$
+     */
+    public Scalar invert() {
+        Scalar res = Scalar.ONE;
+        // Start at the highest set bit of the exponent: squaring ONE for the
+        // leading zero bits cannot change the result. The branch depends only
+        // on the bits of a fixed constant, never on the value of this Scalar.
+        for (int i = Constants.ModMinus2.bitLength() - 1; i >= 0; i--) {
+            res = res.square();
+            if (Constants.ModMinus2.testBit(i)) {
+                res = this.multiply(res);
+            }
+        }
+        return res;
+    }
+
+    /**
+     * Divide this Scalar by {@code a} modulo $\ell$, computed as this Scalar
+     * multiplied by $a^{-1} \bmod \ell$.
+     *
+     * @param a the divisor; should be non-zero (see {@link #invert()})
+     * @return the quotient modulo $\ell$
+     */
+    public Scalar divide(Scalar a) {
+        return this.multiply(a.invert());
+    }
+
     /**
      * Writes this Scalar in radix 16, with coefficients in range $[-8, 8)$.
      *
diff --git a/src/test/java/cafe/cryptography/curve25519/ScalarTest.java b/src/test/java/cafe/cryptography/curve25519/ScalarTest.java
index 23fd22b..5de0eb4 100644
--- a/src/test/java/cafe/cryptography/curve25519/ScalarTest.java
+++ b/src/test/java/cafe/cryptography/curve25519/ScalarTest.java
@@ -238,6 +238,22 @@ public void multiply() {
         assertThat(X_TIMES_Y.multiply(XINV), is(Y));
     }
 
+    @Test
+    public void square() {
+        assertThat(X.square(), is(X.multiply(X)));
+    }
+
+    @Test
+    public void inverse() {
+        assertThat(X.invert(), is(XINV));
+        assertThat(X.multiply(X.invert()), is(Scalar.ONE));
+    }
+
+    @Test
+    public void divide() {
+        assertThat(Y.divide(X), is(Y.multiply(XINV)));
+    }
+
     @Test
     public void nonAdjacentForm() {
         byte[] naf = A_SCALAR.nonAdjacentForm();