Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scalar division and invert functions added #38

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/jmh/java/cafe/cryptography/curve25519/ScalarBench.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ public Scalar multiply() {
return this.a.multiply(this.b);
}

@Benchmark
public Scalar invert() { return this.a.invert(); }

@Benchmark
public Scalar divide() { return this.a.divide(this.b); }

@Benchmark
public Scalar multiplyAndAddManual() {
return this.a.multiply(this.b).add(this.c);
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/cafe/cryptography/curve25519/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package cafe.cryptography.curve25519;

import java.math.BigInteger;

/**
* Various constants and useful parameters.
*/
Expand Down Expand Up @@ -154,4 +156,10 @@ public final class Constants {
*/
public static final RistrettoGeneratorTable RISTRETTO_GENERATOR_TABLE = new RistrettoGeneratorTable(
RISTRETTO_GENERATOR);

/**
* Value of ed25519 base point - 2
* Used for calculating the modular inverse of a Scalar
*/
public static final BigInteger ModMinus2 = new BigInteger("7237005577332262213973186563042994240857116359379907606001950938285454250987");
}
322 changes: 322 additions & 0 deletions src/main/java/cafe/cryptography/curve25519/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,328 @@ public Scalar multiplyAndAdd(Scalar b, Scalar c) {
return new Scalar(result);
}

public Scalar square() {
long a0 = 0x1FFFFF & load_3(this.s, 0);
long a1 = 0x1FFFFF & (load_4(this.s, 2) >> 5);
long a2 = 0x1FFFFF & (load_3(this.s, 5) >> 2);
long a3 = 0x1FFFFF & (load_4(this.s, 7) >> 7);
long a4 = 0x1FFFFF & (load_4(this.s, 10) >> 4);
long a5 = 0x1FFFFF & (load_3(this.s, 13) >> 1);
long a6 = 0x1FFFFF & (load_4(this.s, 15) >> 6);
long a7 = 0x1FFFFF & (load_3(this.s, 18) >> 3);
long a8 = 0x1FFFFF & load_3(this.s, 21);
long a9 = 0x1FFFFF & (load_4(this.s, 23) >> 5);
long a10 = 0x1FFFFF & (load_3(this.s, 26) >> 2);
long a11 = (load_4(this.s, 28) >> 7);
long s0;
long s1;
long s2;
long s3;
long s4;
long s5;
long s6;
long s7;
long s8;
long s9;
long s10;
long s11;
long s12;
long s13;
long s14;
long s15;
long s16;
long s17;
long s18;
long s19;
long s20;
long s21;
long s22;
long s23;
long carry0;
long carry1;
long carry2;
long carry3;
long carry4;
long carry5;
long carry6;
long carry7;
long carry8;
long carry9;
long carry10;
long carry11;
long carry12;
long carry13;
long carry14;
long carry15;
long carry16;
long carry17;
long carry18;
long carry19;
long carry20;
long carry21;
long carry22;

// @formatter:off
s0 = a0*a0;
s1 = a0*a1 + a1*a0;
s2 = a0*a2 + a1*a1 + a2*a0;
s3 = a0*a3 + a1*a2 + a2*a1 + a3*a0;
s4 = a0*a4 + a1*a3 + a2*a2 + a3*a1 + a4*a0;
s5 = a0*a5 + a1*a4 + a2*a3 + a3*a2 + a4*a1 + a5*a0;
s6 = a0*a6 + a1*a5 + a2*a4 + a3*a3 + a4*a2 + a5*a1 + a6*a0;
s7 = a0*a7 + a1*a6 + a2*a5 + a3*a4 + a4*a3 + a5*a2 + a6*a1 + a7*a0;
s8 = a0*a8 + a1*a7 + a2*a6 + a3*a5 + a4*a4 + a5*a3 + a6*a2 + a7*a1 + a8*a0;
s9 = a0*a9 + a1*a8 + a2*a7 + a3*a6 + a4*a5 + a5*a4 + a6*a3 + a7*a2 + a8*a1 + a9*a0;
s10 = a0*a10 + a1*a9 + a2*a8 + a3*a7 + a4*a6 + a5*a5 + a6*a4 + a7*a3 + a8*a2 + a9*a1 + a10*a0;
s11 = a0*a11 + a1*a10 + a2*a9 + a3*a8 + a4*a7 + a5*a6 + a6*a5 + a7*a4 + a8*a3 + a9*a2 + a10*a1 + a11*a0;
s12 = a1*a11 + a2*a10 + a3*a9 + a4*a8 + a5*a7 + a6*a6 + a7*a5 + a8*a4 + a9*a3 + a10*a2 + a11*a1;
s13 = a2*a11 + a3*a10 + a4*a9 + a5*a8 + a6*a7 + a7*a6 + a8*a5 + a9*a4 + a10*a3 + a11*a2;
s14 = a3*a11 + a4*a10 + a5*a9 + a6*a8 + a7*a7 + a8*a6 + a9*a5 + a10*a4 + a11*a3;
s15 = a4*a11 + a5*a10 + a6*a9 + a7*a8 + a8*a7 + a9*a6 + a10*a5 + a11*a4;
s16 = a5*a11 + a6*a10 + a7*a9 + a8*a8 + a9*a7 + a10*a6 + a11*a5;
s17 = a6*a11 + a7*a10 + a8*a9 + a9*a8 + a10*a7 + a11*a6;
s18 = a7*a11 + a8*a10 + a9*a9 + a10*a8 + a11*a7;
s19 = a8*a11 + a9*a10 + a10*a9 + a11*a8;
s20 = a9*a11 + a10*a10 + a11*a9;
s21 = a10*a11 + a11*a10;
s22 = a11*a11;

carry0 = (s0 + (1<<20)) >> 21; s1 += carry0; s0 -= carry0 << 21;
carry2 = (s2 + (1<<20)) >> 21; s3 += carry2; s2 -= carry2 << 21;
carry4 = (s4 + (1<<20)) >> 21; s5 += carry4; s4 -= carry4 << 21;
carry6 = (s6 + (1<<20)) >> 21; s7 += carry6; s6 -= carry6 << 21;
carry8 = (s8 + (1<<20)) >> 21; s9 += carry8; s8 -= carry8 << 21;
carry10 = (s10 + (1<<20)) >> 21; s11 += carry10; s10 -= carry10 << 21;
carry12 = (s12 + (1<<20)) >> 21; s13 += carry12; s12 -= carry12 << 21;
carry14 = (s14 + (1<<20)) >> 21; s15 += carry14; s14 -= carry14 << 21;
carry16 = (s16 + (1<<20)) >> 21; s17 += carry16; s16 -= carry16 << 21;
carry18 = (s18 + (1<<20)) >> 21; s19 += carry18; s18 -= carry18 << 21;
carry20 = (s20 + (1<<20)) >> 21; s21 += carry20; s20 -= carry20 << 21;
carry22 = (s22 + (1<<20)) >> 21; s23 = carry22; s22 -= carry22 << 21;

carry1 = (s1 + (1<<20)) >> 21; s2 += carry1; s1 -= carry1 << 21;
carry3 = (s3 + (1<<20)) >> 21; s4 += carry3; s3 -= carry3 << 21;
carry5 = (s5 + (1<<20)) >> 21; s6 += carry5; s5 -= carry5 << 21;
carry7 = (s7 + (1<<20)) >> 21; s8 += carry7; s7 -= carry7 << 21;
carry9 = (s9 + (1<<20)) >> 21; s10 += carry9; s9 -= carry9 << 21;
carry11 = (s11 + (1<<20)) >> 21; s12 += carry11; s11 -= carry11 << 21;
carry13 = (s13 + (1<<20)) >> 21; s14 += carry13; s13 -= carry13 << 21;
carry15 = (s15 + (1<<20)) >> 21; s16 += carry15; s15 -= carry15 << 21;
carry17 = (s17 + (1<<20)) >> 21; s18 += carry17; s17 -= carry17 << 21;
carry19 = (s19 + (1<<20)) >> 21; s20 += carry19; s19 -= carry19 << 21;
carry21 = (s21 + (1<<20)) >> 21; s22 += carry21; s21 -= carry21 << 21;
// @formatter:on

s11 += s23 * 666643;
s12 += s23 * 470296;
s13 += s23 * 654183;
s14 -= s23 * 997805;
s15 += s23 * 136657;
s16 -= s23 * 683901;

s10 += s22 * 666643;
s11 += s22 * 470296;
s12 += s22 * 654183;
s13 -= s22 * 997805;
s14 += s22 * 136657;
s15 -= s22 * 683901;

s9 += s21 * 666643;
s10 += s21 * 470296;
s11 += s21 * 654183;
s12 -= s21 * 997805;
s13 += s21 * 136657;
s14 -= s21 * 683901;

s8 += s20 * 666643;
s9 += s20 * 470296;
s10 += s20 * 654183;
s11 -= s20 * 997805;
s12 += s20 * 136657;
s13 -= s20 * 683901;

s7 += s19 * 666643;
s8 += s19 * 470296;
s9 += s19 * 654183;
s10 -= s19 * 997805;
s11 += s19 * 136657;
s12 -= s19 * 683901;

s6 += s18 * 666643;
s7 += s18 * 470296;
s8 += s18 * 654183;
s9 -= s18 * 997805;
s10 += s18 * 136657;
s11 -= s18 * 683901;

// @formatter:off
carry6 = (s6 + (1<<20)) >> 21; s7 += carry6; s6 -= carry6 << 21;
carry8 = (s8 + (1<<20)) >> 21; s9 += carry8; s8 -= carry8 << 21;
carry10 = (s10 + (1<<20)) >> 21; s11 += carry10; s10 -= carry10 << 21;
carry12 = (s12 + (1<<20)) >> 21; s13 += carry12; s12 -= carry12 << 21;
carry14 = (s14 + (1<<20)) >> 21; s15 += carry14; s14 -= carry14 << 21;
carry16 = (s16 + (1<<20)) >> 21; s17 += carry16; s16 -= carry16 << 21;

carry7 = (s7 + (1<<20)) >> 21; s8 += carry7; s7 -= carry7 << 21;
carry9 = (s9 + (1<<20)) >> 21; s10 += carry9; s9 -= carry9 << 21;
carry11 = (s11 + (1<<20)) >> 21; s12 += carry11; s11 -= carry11 << 21;
carry13 = (s13 + (1<<20)) >> 21; s14 += carry13; s13 -= carry13 << 21;
carry15 = (s15 + (1<<20)) >> 21; s16 += carry15; s15 -= carry15 << 21;
// @formatter:on

s5 += s17 * 666643;
s6 += s17 * 470296;
s7 += s17 * 654183;
s8 -= s17 * 997805;
s9 += s17 * 136657;
s10 -= s17 * 683901;

s4 += s16 * 666643;
s5 += s16 * 470296;
s6 += s16 * 654183;
s7 -= s16 * 997805;
s8 += s16 * 136657;
s9 -= s16 * 683901;

s3 += s15 * 666643;
s4 += s15 * 470296;
s5 += s15 * 654183;
s6 -= s15 * 997805;
s7 += s15 * 136657;
s8 -= s15 * 683901;

s2 += s14 * 666643;
s3 += s14 * 470296;
s4 += s14 * 654183;
s5 -= s14 * 997805;
s6 += s14 * 136657;
s7 -= s14 * 683901;

s1 += s13 * 666643;
s2 += s13 * 470296;
s3 += s13 * 654183;
s4 -= s13 * 997805;
s5 += s13 * 136657;
s6 -= s13 * 683901;

s0 += s12 * 666643;
s1 += s12 * 470296;
s2 += s12 * 654183;
s3 -= s12 * 997805;
s4 += s12 * 136657;
s5 -= s12 * 683901;

// @formatter:off
carry0 = (s0 + (1<<20)) >> 21; s1 += carry0; s0 -= carry0 << 21;
carry2 = (s2 + (1<<20)) >> 21; s3 += carry2; s2 -= carry2 << 21;
carry4 = (s4 + (1<<20)) >> 21; s5 += carry4; s4 -= carry4 << 21;
carry6 = (s6 + (1<<20)) >> 21; s7 += carry6; s6 -= carry6 << 21;
carry8 = (s8 + (1<<20)) >> 21; s9 += carry8; s8 -= carry8 << 21;
carry10 = (s10 + (1<<20)) >> 21; s11 += carry10; s10 -= carry10 << 21;

carry1 = (s1 + (1<<20)) >> 21; s2 += carry1; s1 -= carry1 << 21;
carry3 = (s3 + (1<<20)) >> 21; s4 += carry3; s3 -= carry3 << 21;
carry5 = (s5 + (1<<20)) >> 21; s6 += carry5; s5 -= carry5 << 21;
carry7 = (s7 + (1<<20)) >> 21; s8 += carry7; s7 -= carry7 << 21;
carry9 = (s9 + (1<<20)) >> 21; s10 += carry9; s9 -= carry9 << 21;
carry11 = (s11 + (1<<20)) >> 21; s12 = carry11; s11 -= carry11 << 21;
// @formatter:on

s0 += s12 * 666643;
s1 += s12 * 470296;
s2 += s12 * 654183;
s3 -= s12 * 997805;
s4 += s12 * 136657;
s5 -= s12 * 683901;

// @formatter:off
carry0 = s0 >> 21; s1 += carry0; s0 -= carry0 << 21;
carry1 = s1 >> 21; s2 += carry1; s1 -= carry1 << 21;
carry2 = s2 >> 21; s3 += carry2; s2 -= carry2 << 21;
carry3 = s3 >> 21; s4 += carry3; s3 -= carry3 << 21;
carry4 = s4 >> 21; s5 += carry4; s4 -= carry4 << 21;
carry5 = s5 >> 21; s6 += carry5; s5 -= carry5 << 21;
carry6 = s6 >> 21; s7 += carry6; s6 -= carry6 << 21;
carry7 = s7 >> 21; s8 += carry7; s7 -= carry7 << 21;
carry8 = s8 >> 21; s9 += carry8; s8 -= carry8 << 21;
carry9 = s9 >> 21; s10 += carry9; s9 -= carry9 << 21;
carry10 = s10 >> 21; s11 += carry10; s10 -= carry10 << 21;
carry11 = s11 >> 21; s12 = carry11; s11 -= carry11 << 21;
// @formatter:on

s0 += s12 * 666643;
s1 += s12 * 470296;
s2 += s12 * 654183;
s3 -= s12 * 997805;
s4 += s12 * 136657;
s5 -= s12 * 683901;

// @formatter:off
carry0 = s0 >> 21; s1 += carry0; s0 -= carry0 << 21;
carry1 = s1 >> 21; s2 += carry1; s1 -= carry1 << 21;
carry2 = s2 >> 21; s3 += carry2; s2 -= carry2 << 21;
carry3 = s3 >> 21; s4 += carry3; s3 -= carry3 << 21;
carry4 = s4 >> 21; s5 += carry4; s4 -= carry4 << 21;
carry5 = s5 >> 21; s6 += carry5; s5 -= carry5 << 21;
carry6 = s6 >> 21; s7 += carry6; s6 -= carry6 << 21;
carry7 = s7 >> 21; s8 += carry7; s7 -= carry7 << 21;
carry8 = s8 >> 21; s9 += carry8; s8 -= carry8 << 21;
carry9 = s9 >> 21; s10 += carry9; s9 -= carry9 << 21;
carry10 = s10 >> 21; s11 += carry10; s10 -= carry10 << 21;
// @formatter:on

byte[] result = new byte[32];
result[0] = (byte) s0;
result[1] = (byte) (s0 >> 8);
result[2] = (byte) ((s0 >> 16) | (s1 << 5));
result[3] = (byte) (s1 >> 3);
result[4] = (byte) (s1 >> 11);
result[5] = (byte) ((s1 >> 19) | (s2 << 2));
result[6] = (byte) (s2 >> 6);
result[7] = (byte) ((s2 >> 14) | (s3 << 7));
result[8] = (byte) (s3 >> 1);
result[9] = (byte) (s3 >> 9);
result[10] = (byte) ((s3 >> 17) | (s4 << 4));
result[11] = (byte) (s4 >> 4);
result[12] = (byte) (s4 >> 12);
result[13] = (byte) ((s4 >> 20) | (s5 << 1));
result[14] = (byte) (s5 >> 7);
result[15] = (byte) ((s5 >> 15) | (s6 << 6));
result[16] = (byte) (s6 >> 2);
result[17] = (byte) (s6 >> 10);
result[18] = (byte) ((s6 >> 18) | (s7 << 3));
result[19] = (byte) (s7 >> 5);
result[20] = (byte) (s7 >> 13);
result[21] = (byte) s8;
result[22] = (byte) (s8 >> 8);
result[23] = (byte) ((s8 >> 16) | (s9 << 5));
result[24] = (byte) (s9 >> 3);
result[25] = (byte) (s9 >> 11);
result[26] = (byte) ((s9 >> 19) | (s10 << 2));
result[27] = (byte) (s10 >> 6);
result[28] = (byte) ((s10 >> 14) | (s11 << 7));
result[29] = (byte) (s11 >> 1);
result[30] = (byte) (s11 >> 9);
result[31] = (byte) (s11 >> 17);
return new Scalar(result);
}

public Scalar invert() {
Scalar res = Scalar.ONE;
for(int i = 255; i >= 0; i--){
res = res.square();
if (Constants.ModMinus2.testBit(i)) {
res = this.multiplyAndAdd(res, Scalar.ZERO);
}
}
return res;
}

public Scalar divide(Scalar a) {
Scalar inv_a = a.invert();
return this.multiplyAndAdd(inv_a, Scalar.ZERO);
}

/**
* Writes this Scalar in radix 16, with coefficients in range $[-8, 8)$.
*
Expand Down
10 changes: 10 additions & 0 deletions src/test/java/cafe/cryptography/curve25519/ScalarTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,16 @@ public void multiply() {
assertThat(X_TIMES_Y.multiply(XINV), is(Y));
}

@Test
public void inverse() {
assertThat(X.invert(), is(XINV));
}

@Test
public void divide() {
assertThat(Y.divide(X), is(Y.multiply(XINV)));
}

@Test
public void nonAdjacentForm() {
byte[] naf = A_SCALAR.nonAdjacentForm();
Expand Down