From 0e53b81bd915b9c1e507587926a789681a9fbb1c Mon Sep 17 00:00:00 2001 From: armfazh Date: Tue, 6 Jul 2021 12:36:45 -0700 Subject: [PATCH] Resolving suggestions after code review. --- pke/kyber/internal/common/field_test.go | 12 ++++++------ pke/kyber/internal/common/ntt_test.go | 8 ++++---- pke/kyber/internal/common/poly_test.go | 9 +++++---- sign/dilithium/internal/common/field_test.go | 14 ++++++++------ sign/dilithium/internal/common/ntt_test.go | 6 ++---- sign/dilithium/internal/common/pack_test.go | 12 +++--------- 6 files changed, 28 insertions(+), 33 deletions(-) diff --git a/pke/kyber/internal/common/field_test.go b/pke/kyber/internal/common/field_test.go index 1fd0d9c9..551ff5d3 100644 --- a/pke/kyber/internal/common/field_test.go +++ b/pke/kyber/internal/common/field_test.go @@ -36,29 +36,29 @@ func TestBarrettReduceFull(t *testing.T) { } } -func randSliceUint32(N uint) []uint32 { - bytes := make([]uint8, 4*N) +func randSliceUint32WithMax(length uint, max uint32) []uint32 { + bytes := make([]uint8, 4*length) n, err := rand.Read(bytes) if err != nil { panic(err) } else if n < len(bytes) { panic("short read from RNG") } - x := make([]uint32, N) + x := make([]uint32, length) for i := range x { - x[i] = binary.LittleEndian.Uint32(bytes[4*i:]) + x[i] = binary.LittleEndian.Uint32(bytes[4*i:]) % max } return x } func TestMontReduce(t *testing.T) { N := 1000 - r := randSliceUint32(uint(N)) max := uint32(Q) * (1 << 16) mid := int32(Q) * (1 << 15) + r := randSliceUint32WithMax(uint(N), max) for i := 0; i < N; i++ { - x := int32(r[i]%max) - mid + x := int32(r[i]) - mid y := montReduce(x) if modQ32(x) != modQ32(int32(y)*(1<<16)) { t.Fatalf("%d", x) diff --git a/pke/kyber/internal/common/ntt_test.go b/pke/kyber/internal/common/ntt_test.go index a678c428..81ef21ba 100644 --- a/pke/kyber/internal/common/ntt_test.go +++ b/pke/kyber/internal/common/ntt_test.go @@ -31,18 +31,18 @@ func BenchmarkInvNTTGeneric(b *testing.B) { } func (p *Poly) Rand() { - r := randSliceUint32(uint(N)) max := uint32(Q) + r := randSliceUint32WithMax(uint(N), max) for i := 0; i < N; i++ { - p[i] = int16(r[i] % max) + p[i] = int16(r[i]) } } func (p *Poly) RandAbsLeQ() { - r := randSliceUint32(uint(N)) max := 2 * uint32(Q) + r := randSliceUint32WithMax(uint(N), max) for i := 0; i < N; i++ { - p[i] = int16(int32(r[i]%max) - int32(Q)) + p[i] = int16(int32(r[i]) - int32(Q)) } } diff --git a/pke/kyber/internal/common/poly_test.go b/pke/kyber/internal/common/poly_test.go index d19633cc..fcce5fa7 100644 --- a/pke/kyber/internal/common/poly_test.go +++ b/pke/kyber/internal/common/poly_test.go @@ -7,10 +7,10 @@ import ( ) func (p *Poly) RandAbsLe9Q() { - r := randSliceUint32(uint(N)) max := 9 * uint32(Q) + r := randSliceUint32WithMax(uint(N), max) for i := 0; i < N; i++ { - p[i] = int16(int32(r[i] % max)) + p[i] = int16(int32(r[i])) } } @@ -27,9 +27,10 @@ func TestDecompressMessage(t *testing.T) { var m, m2 [PlaintextSize]byte var p Poly for i := 0; i < 1000; i++ { - _, err := rand.Read(m[:]) - if err != nil { + if n, err := rand.Read(m[:]); err != nil { t.Error(err) + } else if n != len(m) { + t.Fatal("short read from RNG") } p.DecompressMessage(m[:]) diff --git a/sign/dilithium/internal/common/field_test.go b/sign/dilithium/internal/common/field_test.go index aa837e74..269e3e98 100644 --- a/sign/dilithium/internal/common/field_test.go +++ b/sign/dilithium/internal/common/field_test.go @@ -4,22 +4,24 @@ import ( "crypto/rand" "encoding/binary" "flag" + "math" "testing" ) var runVeryLongTest = flag.Bool("very-long", false, "runs very long tests") -func randSliceUint32(N uint) []uint32 { - bytes := make([]uint8, 4*N) - n, err := rand.Read(bytes) - if err != nil { +func randSliceUint32(length uint) []uint32 { return randSliceUint32WithMax(length, math.MaxUint32) } + +func randSliceUint32WithMax(length uint, max uint32) []uint32 { + bytes := make([]uint8, 4*length) + if n, err := rand.Read(bytes); err != nil { panic(err) } else if n < len(bytes) { panic("short read from RNG") } - x := make([]uint32, N) + x := make([]uint32, length) for i := range x { - x[i] = binary.LittleEndian.Uint32(bytes[4*i:]) + x[i] = binary.LittleEndian.Uint32(bytes[4*i:]) % max } return x } diff --git a/sign/dilithium/internal/common/ntt_test.go b/sign/dilithium/internal/common/ntt_test.go index e760ef99..4524f245 100644 --- a/sign/dilithium/internal/common/ntt_test.go +++ b/sign/dilithium/internal/common/ntt_test.go @@ -3,11 +3,9 @@ package common import "testing" func (p *Poly) RandLe2Q() { - r := randSliceUint32(N) max := 2 * uint32(Q) - for i := uint(0); i < N; i++ { - p[i] = r[i] % max - } + r := randSliceUint32WithMax(N, max) + copy(p[:], r) } func TestNTTAgainstGeneric(t *testing.T) { diff --git a/sign/dilithium/internal/common/pack_test.go b/sign/dilithium/internal/common/pack_test.go index 715d9933..315e11ee 100644 --- a/sign/dilithium/internal/common/pack_test.go +++ b/sign/dilithium/internal/common/pack_test.go @@ -1,20 +1,14 @@ package common -import ( - "crypto/rand" - "testing" -) +import "testing" func TestPackLe16AgainstGeneric(t *testing.T) { var p Poly var buf1, buf2 [PolyLe16Size]byte - pp := make([]uint8, 256) for j := 0; j < 1000; j++ { - _, _ = rand.Read(pp) - for i := 0; i < 256; i++ { - p[i] = uint32(pp[i] & 0xF) - } + pp := randSliceUint32WithMax(N, 16) + copy(p[:], pp) p.PackLe16(buf1[:]) p.packLe16Generic(buf2[:]) if buf1 != buf2 {