diff --git a/common_test.go b/common_test.go index 3b5ffa8..b445897 100644 --- a/common_test.go +++ b/common_test.go @@ -1,12 +1,10 @@ package mls import ( - "bytes" "encoding/hex" - "fmt" - "reflect" - "runtime" "testing" + + "github.com/stretchr/testify/require" ) type TestEnum uint8 @@ -19,10 +17,10 @@ var ( func TestValidateEnum(t *testing.T) { err := validateEnum(TestEnumVal0, TestEnumVal0, TestEnumVal1) - assertNotError(t, err, "Failed to recognize known enum value") + require.Nil(t, err) err = validateEnum(TestEnumInvalid, TestEnumVal0, TestEnumVal1) - assertError(t, err, "Failed to flag invalid enum value") + require.Error(t, err) } ////////// @@ -34,80 +32,3 @@ func unhex(h string) []byte { } return b } - -////////// - -func assertTrue(t *testing.T, test bool, msg string) { - t.Helper() - prefix := string("") - for i := 1; ; i++ { - _, file, line, ok := runtime.Caller(i) - if !ok { - break - } - prefix = fmt.Sprintf("%v: %d\n", file, line) + prefix - } - if !test { - t.Fatalf(prefix + msg) - } -} - -func assertError(t *testing.T, err error, msg string) { - t.Helper() - assertTrue(t, err != nil, msg) -} - -func assertNotError(t *testing.T, err error, msg string) { - t.Helper() - if err != nil { - msg += ": " + err.Error() - } - assertTrue(t, err == nil, msg) -} - -func assertPanic(t *testing.T, f func(), msg string) { - defer func() { - if r := recover(); r == nil { - assertTrue(t, false, msg) - } - }() - - f() -} - -func assertNil(t *testing.T, x interface{}, msg string) { - t.Helper() - assertTrue(t, x == nil, msg) -} - -func assertNotNil(t *testing.T, x interface{}, msg string) { - t.Helper() - assertTrue(t, x != nil, msg) -} - -func assertEquals(t *testing.T, a, b interface{}) { - t.Helper() - assertTrue(t, a == b, fmt.Sprintf("%+v != %+v", a, b)) -} - -func assertByteEquals(t *testing.T, a, b []byte) { - t.Helper() - assertTrue(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b))) -} - -func assertNotByteEquals(t *testing.T, a, b []byte) { - t.Helper() - assertTrue(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b))) -} - -func assertDeepEquals(t *testing.T, a, b interface{}) { - t.Helper() - assertTrue(t, reflect.DeepEqual(a, b), fmt.Sprintf("%+v != %+v", a, b)) -} - -func assertSameType(t *testing.T, a, b interface{}) { - t.Helper() - A := reflect.TypeOf(a) - B := reflect.TypeOf(b) - assertTrue(t, A == B, fmt.Sprintf("%s != %s", A.Name(), B.Name())) -} diff --git a/credential_test.go b/credential_test.go index 0d9e9c3..4935015 100644 --- a/credential_test.go +++ b/credential_test.go @@ -2,29 +2,31 @@ package mls import ( "testing" + + "github.com/stretchr/testify/require" ) func TestBasicCredential(t *testing.T) { identity := []byte("res ipsa") scheme := Ed25519 priv, err := scheme.Generate() - assertNotError(t, err, "Error generating private key") + require.Nil(t, err) cred := NewBasicCredential(identity, scheme, &priv) - assertTrue(t, cred.Equals(*cred), "Credential not equal to self") - assertEquals(t, cred.Type(), CredentialTypeBasic) - assertEquals(t, cred.Scheme(), scheme) - assertDeepEquals(t, *cred.PublicKey(), priv.PublicKey) + require.True(t, cred.Equals(*cred)) + require.Equal(t, cred.Type(), CredentialTypeBasic) + require.Equal(t, cred.Scheme(), scheme) + require.Equal(t, *cred.PublicKey(), priv.PublicKey) } func TestCredentialErrorCases(t *testing.T) { cred0 := Credential{nil, nil} - assertTrue(t, !cred0.Equals(cred0), "Bad credentials should not be equal") - assertEquals(t, cred0.Type(), CredentialTypeInvalid) - assertPanic(t, func() { cred0.PublicKey() }, "Public key for bad credential") - assertPanic(t, func() { cred0.Scheme() }, "Scheme for bad credential") + require.True(t, !cred0.Equals(cred0)) + require.Equal(t, cred0.Type(), CredentialTypeInvalid) + require.Panics(t, func() { cred0.PublicKey() }) + require.Panics(t, func() { cred0.Scheme() }) _, err := cred0.MarshalTLS() - assertError(t, err, "Marshal for bad credential") + require.Error(t, err) } diff --git a/crypto_test.go b/crypto_test.go index 4682842..e08a223 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/bifurcation/mint/syntax" + "github.com/stretchr/testify/require" ) var supportedSuites = []CipherSuite{ @@ -36,9 +37,7 @@ func TestDigest(t *testing.T) { } d := suite.digest(in) - if !bytes.Equal(d, out) { - t.Fatalf("Incorrect digest: %x != %x", d, out) - } + require.Equal(t, d, out) } } @@ -75,24 +74,16 @@ func TestEncryptDecrypt(t *testing.T) { } aead, err := suite.newAEAD(key) - if err != nil { - t.Fatalf("Error creating AEAD: %v", err) - } + require.Nil(t, err) // Test encryption encrypted := aead.Seal(nil, nonce, pt, aad) - if !bytes.Equal(ct, encrypted) { - t.Fatalf("Incorrect encryption: %x != %x", ct, encrypted) - } + require.Equal(t, ct, encrypted) // Test decryption decrypted, err := aead.Open(nil, nonce, ct, aad) - if err != nil { - t.Fatalf("Error in decryption: %v", err) - } - if !bytes.Equal(pt, decrypted) { - t.Fatalf("Incorrect decryption: %x != %x", pt, decrypted) - } + require.Nil(t, err) + require.Equal(t, pt, decrypted) } } @@ -109,17 +100,17 @@ func TestHPKE(t *testing.T) { encryptDecrypt := func(suite CipherSuite) func(t *testing.T) { return func(t *testing.T) { priv, err := suite.hpke().Generate() - assertNotError(t, err, "Error generating HPKE key") + require.Nil(t, err) priv, err = suite.hpke().Derive(seed) - assertNotError(t, err, "Error deriving HPKE key") + require.Nil(t, err) encrypted, err := suite.hpke().Encrypt(priv.PublicKey, aad, original) - assertNotError(t, err, "Error in HPKE encryption") + require.Nil(t, err) decrypted, err := suite.hpke().Decrypt(priv, aad, encrypted) - assertNotError(t, err, "Error in HPKE decryption") - assertByteEquals(t, original, decrypted) + require.Nil(t, err) + require.Equal(t, original, decrypted) } } @@ -135,16 +126,16 @@ func TestSignVerify(t *testing.T) { signVerify := func(scheme SignatureScheme) func(t *testing.T) { return func(t *testing.T) { priv, err := scheme.Generate() - assertNotError(t, err, "Error generating signing key") + require.Nil(t, err) priv, err = scheme.Derive(seed) - assertNotError(t, err, "Error generating signing key") + require.Nil(t, err) signature, err := scheme.Sign(&priv, message) - assertNotError(t, err, "Error signing") + require.Nil(t, err) verified := scheme.Verify(&priv.PublicKey, message, signature) - assertTrue(t, verified, "Signature failed to verify") + require.True(t, verified) } } @@ -194,32 +185,32 @@ func generateCryptoVectors(t *testing.T) []byte { priv, err = tc.CipherSuite.hpke().Derive(tv.DeriveKeyPairSeed) tc.DeriveKeyPairPub = priv.PublicKey - assertNotError(t, err, "Error deriving HPKE key pair") + require.Nil(t, err) tc.HPKEOut, err = tc.CipherSuite.hpke().Encrypt(tc.DeriveKeyPairPub, tv.HPKEAAD, tv.HPKEPlaintext) - assertNotError(t, err, "Error in HPKE encryption") + require.Nil(t, err) } vec, err := syntax.Marshal(tv) - assertNotError(t, err, "Error marshaling test vectors") + require.Nil(t, err) return vec } func verifyCryptoVectors(t *testing.T, data []byte) { var tv CryptoTestVectors _, err := syntax.Unmarshal(data, &tv) - assertNotError(t, err, "Malformed crypto test vectors") + require.Nil(t, err) for _, tc := range tv.Cases { hkdfExtractOut := tc.CipherSuite.hkdfExtract(tv.HKDFExtractSalt, tv.HKDFExtractIKM) - assertByteEquals(t, hkdfExtractOut, tc.HKDFExtractOut) + require.Equal(t, hkdfExtractOut, tc.HKDFExtractOut) priv, err = tc.CipherSuite.hpke().Derive(tv.DeriveKeyPairSeed) - assertNotError(t, err, "Error deriving HPKE key pair") - assertByteEquals(t, priv.PublicKey.Data, tc.DeriveKeyPairPub.Data) + require.Nil(t, err) + require.Equal(t, priv.PublicKey.Data, tc.DeriveKeyPairPub.Data) plaintext, err := tc.CipherSuite.hpke().Decrypt(priv, tv.HPKEAAD, tc.HPKEOut) - assertNotError(t, err, "Error in HPKE decryption") - assertDeepEquals(t, plaintext, tv.HPKEPlaintext) + require.Nil(t, err) + require.Equal(t, plaintext, tv.HPKEPlaintext) } } diff --git a/key-schedule_test.go b/key-schedule_test.go index 80fe2fe..ea4b036 100644 --- a/key-schedule_test.go +++ b/key-schedule_test.go @@ -2,8 +2,10 @@ package mls import ( "bytes" - "github.com/bifurcation/mint/syntax" "testing" + + "github.com/bifurcation/mint/syntax" + "github.com/stretchr/testify/require" ) // XXX(rlb): Uncomment this to see a graphical illustration of how the @@ -40,38 +42,38 @@ func TestKeySchedule(t *testing.T) { targetGeneration := uint32(3) checkEpoch := func(epoch *keyScheduleEpoch, size leafCount) { - assertEquals(t, epoch.Suite, suite) - assertEquals(t, len(epoch.EpochSecret), secretSize) - assertEquals(t, len(epoch.SenderDataSecret), secretSize) - assertEquals(t, len(epoch.SenderDataKey), keySize) - assertEquals(t, len(epoch.HandshakeSecret), secretSize) - assertEquals(t, len(epoch.ApplicationSecret), secretSize) - assertEquals(t, len(epoch.ConfirmationKey), secretSize) - assertEquals(t, len(epoch.InitSecret), secretSize) - assertNotNil(t, epoch.HandshakeKeys, "Missing handshake keys") - assertNotNil(t, epoch.HandshakeKeys, "Missing application keys") + require.Equal(t, epoch.Suite, suite) + require.Equal(t, len(epoch.EpochSecret), secretSize) + require.Equal(t, len(epoch.SenderDataSecret), secretSize) + require.Equal(t, len(epoch.SenderDataKey), keySize) + require.Equal(t, len(epoch.HandshakeSecret), secretSize) + require.Equal(t, len(epoch.ApplicationSecret), secretSize) + require.Equal(t, len(epoch.ConfirmationKey), secretSize) + require.Equal(t, len(epoch.InitSecret), secretSize) + require.NotNil(t, epoch.HandshakeKeys) + require.NotNil(t, epoch.HandshakeKeys) for i := leafIndex(0); i < leafIndex(size); i += 1 { // Test successful generation hs, err := epoch.HandshakeKeys.Get(i, targetGeneration) - assertNotError(t, err, "Error in handshake key generation") - assertEquals(t, len(hs.Key), keySize) - assertEquals(t, len(hs.Nonce), nonceSize) + require.Nil(t, err) + require.Equal(t, len(hs.Key), keySize) + require.Equal(t, len(hs.Nonce), nonceSize) app, err := epoch.ApplicationKeys.Get(i, targetGeneration) - assertNotError(t, err, "Error in handshake key generation") - assertEquals(t, len(app.Key), keySize) - assertEquals(t, len(app.Nonce), nonceSize) + require.Nil(t, err) + require.Equal(t, len(app.Key), keySize) + require.Equal(t, len(app.Nonce), nonceSize) epoch.HandshakeKeys.Erase(i, targetGeneration) epoch.ApplicationKeys.Erase(i, targetGeneration) // Test forward secrecy _, err = epoch.HandshakeKeys.Get(i, targetGeneration) - assertError(t, err, "Reused handshake key") + require.Error(t, err) _, err = epoch.ApplicationKeys.Get(i, targetGeneration) - assertError(t, err, "Reused handshake key") + require.Error(t, err) } } @@ -83,37 +85,37 @@ func TestKeySchedule(t *testing.T) { // Check that marshal/unmarshal works epoch2m, err := syntax.Marshal(epoch2) - assertNotError(t, err, "Error in key schedule marshal") + require.Nil(t, err) var epoch2u keyScheduleEpoch _, err = syntax.Unmarshal(epoch2m, &epoch2u) - assertNotError(t, err, "Error in key schedule unmarshal") + require.Nil(t, err) epoch2u.enableKeySources() // Verify that the contents match (not the group key generators) - assertDeepEquals(t, epoch2.Suite, epoch2u.Suite) - assertDeepEquals(t, epoch2.EpochSecret, epoch2u.EpochSecret) - assertDeepEquals(t, epoch2.SenderDataSecret, epoch2u.SenderDataSecret) - assertDeepEquals(t, epoch2.SenderDataKey, epoch2u.SenderDataKey) - assertDeepEquals(t, epoch2.HandshakeSecret, epoch2u.HandshakeSecret) - assertDeepEquals(t, epoch2.ApplicationSecret, epoch2u.ApplicationSecret) - assertDeepEquals(t, epoch2.ConfirmationKey, epoch2u.ConfirmationKey) - assertDeepEquals(t, epoch2.InitSecret, epoch2u.InitSecret) - assertDeepEquals(t, epoch2.HandshakeBaseKeys, epoch2u.HandshakeBaseKeys) - assertDeepEquals(t, epoch2.ApplicationBaseKeys, epoch2u.ApplicationBaseKeys) - assertDeepEquals(t, epoch2.HandshakeRatchets, epoch2u.HandshakeRatchets) - assertDeepEquals(t, epoch2.ApplicationRatchets, epoch2u.ApplicationRatchets) + require.Equal(t, epoch2.Suite, epoch2u.Suite) + require.Equal(t, epoch2.EpochSecret, epoch2u.EpochSecret) + require.Equal(t, epoch2.SenderDataSecret, epoch2u.SenderDataSecret) + require.Equal(t, epoch2.SenderDataKey, epoch2u.SenderDataKey) + require.Equal(t, epoch2.HandshakeSecret, epoch2u.HandshakeSecret) + require.Equal(t, epoch2.ApplicationSecret, epoch2u.ApplicationSecret) + require.Equal(t, epoch2.ConfirmationKey, epoch2u.ConfirmationKey) + require.Equal(t, epoch2.InitSecret, epoch2u.InitSecret) + require.Equal(t, epoch2.HandshakeBaseKeys, epoch2u.HandshakeBaseKeys) + require.Equal(t, epoch2.ApplicationBaseKeys, epoch2u.ApplicationBaseKeys) + require.Equal(t, epoch2.HandshakeRatchets, epoch2u.HandshakeRatchets) + require.Equal(t, epoch2.ApplicationRatchets, epoch2u.ApplicationRatchets) // Verify that we can't get a key for the target generation (because it's // already consumed) _, err = epoch2u.HandshakeKeys.Get(0, targetGeneration) - assertError(t, err, "Replayed an already-used key") + require.Error(t, err) // Verify that we can get one for the next epoch, and it's the same as the // original key schedule would have produced _, err = epoch2u.HandshakeKeys.Get(0, targetGeneration+1) - assertNotError(t, err, "Failed to get the next key") + require.Nil(t, err) } /// @@ -159,7 +161,7 @@ func generateKeyScheduleVectors(t *testing.T) []byte { } encCtx, err := syntax.Marshal(baseGrpCtx) - assertNotError(t, err, "grp context marshal") + require.Nil(t, err) tv.NumEpochs = 50 tv.TargetGeneration = 3 tv.BaseGroupContext = encCtx @@ -216,19 +218,19 @@ func generateKeyScheduleVectors(t *testing.T) []byte { } vec, err := syntax.Marshal(tv) - assertNotError(t, err, "Error marshaling test vectors") + require.Nil(t, err) return vec } func verifyKeyScheduleVectors(t *testing.T, data []byte) { var tv KsTestVectors _, err := syntax.Unmarshal(data, &tv) - assertNotError(t, err, "Malformed message test vectors") + require.Nil(t, err) for _, tc := range tv.Cases { suite := tc.CipherSuite var grpCtx GroupContext _, err := syntax.Unmarshal(tv.BaseGroupContext, &grpCtx) - assertNotError(t, err, "grpCtx unmarshal") + require.Nil(t, err) var myEpoch keyScheduleEpoch myEpoch.Suite = suite myEpoch.InitSecret = tv.BaseInitSecret @@ -236,25 +238,25 @@ func verifyKeyScheduleVectors(t *testing.T, data []byte) { ctx, _ := syntax.Marshal(grpCtx) myEpoch = myEpoch.Next(epoch.NumMembers, epoch.UpdateSecret, ctx) // check the secrets - assertByteEquals(t, myEpoch.EpochSecret, epoch.EpochSecret) - assertByteEquals(t, myEpoch.SenderDataSecret, epoch.SenderDataSecret) - assertByteEquals(t, myEpoch.SenderDataKey, epoch.SenderDataKey) - assertByteEquals(t, myEpoch.HandshakeSecret, epoch.HandshakeSecret) - assertByteEquals(t, myEpoch.ApplicationSecret, epoch.AppSecret) - assertByteEquals(t, myEpoch.ConfirmationKey, epoch.ConfirmationKey) - assertByteEquals(t, myEpoch.InitSecret, epoch.InitSecret) + require.Equal(t, myEpoch.EpochSecret, epoch.EpochSecret) + require.Equal(t, myEpoch.SenderDataSecret, epoch.SenderDataSecret) + require.Equal(t, myEpoch.SenderDataKey, epoch.SenderDataKey) + require.Equal(t, myEpoch.HandshakeSecret, epoch.HandshakeSecret) + require.Equal(t, myEpoch.ApplicationSecret, epoch.AppSecret) + require.Equal(t, myEpoch.ConfirmationKey, epoch.ConfirmationKey) + require.Equal(t, myEpoch.InitSecret, epoch.InitSecret) //check the keys for i := 0; leafCount(i) < epoch.NumMembers; i++ { hs, err := myEpoch.HandshakeKeys.Get(leafIndex(i), tv.TargetGeneration) - assertNotError(t, err, "hs keys") - assertByteEquals(t, hs.Key, epoch.HandshakeKeys[i].Key) - assertByteEquals(t, hs.Nonce, epoch.HandshakeKeys[i].Nonce) + require.Nil(t, err) + require.Equal(t, hs.Key, epoch.HandshakeKeys[i].Key) + require.Equal(t, hs.Nonce, epoch.HandshakeKeys[i].Nonce) as, err := myEpoch.ApplicationKeys.Get(leafIndex(i), tv.TargetGeneration) - assertNotError(t, err, "as keys") - assertByteEquals(t, as.Key, epoch.AppKeys[i].Key) - assertByteEquals(t, as.Nonce, epoch.AppKeys[i].Nonce) + require.Nil(t, err) + require.Equal(t, as.Key, epoch.AppKeys[i].Key) + require.Equal(t, as.Nonce, epoch.AppKeys[i].Nonce) } grpCtx.Epoch += 1 diff --git a/messages_test.go b/messages_test.go index c5927ff..ad0988f 100644 --- a/messages_test.go +++ b/messages_test.go @@ -2,8 +2,10 @@ package mls import ( "bytes" - "github.com/bifurcation/mint/syntax" "testing" + + "github.com/bifurcation/mint/syntax" + "github.com/stretchr/testify/require" ) var ( @@ -155,11 +157,11 @@ var ( func roundTrip(original interface{}, decoded interface{}) func(t *testing.T) { return func(t *testing.T) { encoded, err := syntax.Marshal(original) - assertNotError(t, err, "Fail to Marshal") + require.Nil(t, err) _, err = syntax.Unmarshal(encoded, decoded) - assertNotError(t, err, "Fail to Unmarshal") - assertDeepEquals(t, decoded, original) + require.Nil(t, err) + require.Equal(t, decoded, original) } } @@ -182,9 +184,9 @@ func TestMessagesMarshalUnmarshal(t *testing.T) { func TestWelcomeMarshalUnMarshalWithDecryption(t *testing.T) { // a tree with 2 members treeAB := newTestRatchetTree(t, supportedSuites[0], [][]byte{secretA, secretB}, []Credential{credA, credB}) - assertTrue(t, treeAB.size() == 2, "size mismatch") - assertEquals(t, *treeAB.GetCredential(leafIndex(0)), credA) - assertEquals(t, *treeAB.GetCredential(leafIndex(1)), credB) + require.Equal(t, treeAB.size(), leafCount(2)) + require.Equal(t, *treeAB.GetCredential(leafIndex(0)), credA) + require.Equal(t, *treeAB.GetCredential(leafIndex(1)), credB) cs := supportedSuites[0] secret, _ := getRandomBytes(32) @@ -213,12 +215,12 @@ func TestWelcomeMarshalUnMarshalWithDecryption(t *testing.T) { // it matches. ekp := w2.EncryptedKeyPackages[0] pt, err := cs.hpke().Decrypt(ikPriv, []byte{}, ekp.EncryptedPackage) - assertNotError(t, err, "decryption error") + require.Nil(t, err) w2kp := new(KeyPackage) _, err = syntax.Unmarshal(pt, w2kp) - assertNotError(t, err, "unmarshal failure for decrypted KeyPackage") - assertByteEquals(t, epochSecret, w2kp.EpochSecret) + require.Nil(t, err) + require.Equal(t, epochSecret, w2kp.EpochSecret) } /// @@ -257,21 +259,21 @@ type MessageTestVectors struct { //helpers func groupInfoMatch(t *testing.T, l, r GroupInfo) { - assertByteEquals(t, l.GroupID, r.GroupID) - assertEquals(t, l.Epoch, r.Epoch) - assertTrue(t, l.Tree.Equals(&r.Tree), "tree unequal") - assertByteEquals(t, l.ConfirmedTranscriptHash, r.ConfirmedTranscriptHash) - assertByteEquals(t, l.InterimTranscriptHash, r.InterimTranscriptHash) - assertByteEquals(t, l.Confirmation, r.Confirmation) - assertEquals(t, l.SignerIndex, r.SignerIndex) - assertByteEquals(t, l.Signature, r.Signature) + require.Equal(t, l.GroupID, r.GroupID) + require.Equal(t, l.Epoch, r.Epoch) + require.True(t, l.Tree.Equals(&r.Tree)) + require.Equal(t, l.ConfirmedTranscriptHash, r.ConfirmedTranscriptHash) + require.Equal(t, l.InterimTranscriptHash, r.InterimTranscriptHash) + require.Equal(t, l.Confirmation, r.Confirmation) + require.Equal(t, l.SignerIndex, r.SignerIndex) + require.Equal(t, l.Signature, r.Signature) } func commitMatch(t *testing.T, l, r Commit) { - assertDeepEquals(t, l.Adds, r.Adds) - assertDeepEquals(t, l.Removes, r.Removes) - assertDeepEquals(t, l.Updates, r.Updates) - assertDeepEquals(t, l.Ignored, r.Ignored) + require.Equal(t, l.Adds, r.Adds) + require.Equal(t, l.Removes, r.Removes) + require.Equal(t, l.Updates, r.Updates) + require.Equal(t, l.Ignored, r.Ignored) } /// Gen and Verify @@ -297,12 +299,12 @@ func generateMessageVectors(t *testing.T) []byte { scheme := schemes[i] // hpke priv, err := suite.hpke().Derive(tv.DHSeed) - assertNotError(t, err, "priv key failure") + require.Nil(t, err) pub := priv.PublicKey // identity sigPriv, err := scheme.Derive(tv.SigSeed) - assertNotError(t, err, "sigPriv failure") + require.Nil(t, err) sigPub := sigPriv.PublicKey bc := &BasicCredential{ @@ -317,7 +319,7 @@ func generateMessageVectors(t *testing.T) []byte { []Credential{cred, cred, cred, cred}) err = ratchetTree.BlankPath(leafIndex(2), true) - assertNotError(t, err, "rtree blank path") + require.Nil(t, err) dp, _ := ratchetTree.Encap(leafIndex(0), []byte{}, tv.Random) @@ -331,7 +333,7 @@ func generateMessageVectors(t *testing.T) []byte { } cikM, err := syntax.Marshal(cik) - assertNotError(t, err, "cik marshal") + require.Nil(t, err) // Welcome @@ -347,24 +349,24 @@ func generateMessageVectors(t *testing.T) []byte { } giM, err := syntax.Marshal(gi) - assertNotError(t, err, "grpInfo marshal") + require.Nil(t, err) kp := KeyPackage{ EpochSecret: tv.Random, } kpM, err := syntax.Marshal(kp) - assertNotError(t, err, "keyy package marshal") + require.Nil(t, err) encPayload, err := suite.hpke().Encrypt(pub, []byte{}, tv.Random) - assertNotError(t, err, "encrypt ekp") + require.Nil(t, err) ekp := EncryptedKeyPackage{ ClientInitKeyHash: tv.Random, EncryptedPackage: encPayload, } ekpM, err := syntax.Marshal(ekp) - assertNotError(t, err, "encrypted key package marshal") + require.Nil(t, err) var welcome Welcome welcome.Version = SupportedVersionMLS10 @@ -373,7 +375,7 @@ func generateMessageVectors(t *testing.T) []byte { welcome.EncryptedGroupInfo = tv.Random welM, err := syntax.Marshal(welcome) - assertNotError(t, err, "welcome marshal") + require.Nil(t, err) // proposals addProposal := &Proposal{ @@ -393,7 +395,7 @@ func generateMessageVectors(t *testing.T) []byte { addHs.Signature = Signature{tv.Random} addM, err := syntax.Marshal(addHs) - assertNotError(t, err, "add HS marshal") + require.Nil(t, err) updateProposal := &Proposal{ Update: &UpdateProposal{ @@ -412,7 +414,7 @@ func generateMessageVectors(t *testing.T) []byte { updateHs.Signature = Signature{tv.Random} updateM, err := syntax.Marshal(updateHs) - assertNotError(t, err, "update HS marshal") + require.Nil(t, err) removeProposal := &Proposal{ Remove: &RemoveProposal{ @@ -431,7 +433,7 @@ func generateMessageVectors(t *testing.T) []byte { removeHs.Signature = Signature{tv.Random} remM, err := syntax.Marshal(removeHs) - assertNotError(t, err, "remove HS marshal") + require.Nil(t, err) // commit proposal := []ProposalID{{tv.Random}, {tv.Random}} @@ -444,7 +446,7 @@ func generateMessageVectors(t *testing.T) []byte { } commitM, err := syntax.Marshal(commit) - assertNotError(t, err, "commit marshal") + require.Nil(t, err) //MlsCiphertext ct := MLSCiphertext{ @@ -457,7 +459,7 @@ func generateMessageVectors(t *testing.T) []byte { } ctM, err := syntax.Marshal(ct) - assertNotError(t, err, "MLSCiphertext marshal") + require.Nil(t, err) tc := MessageTestCase{ CipherSuite: suite, @@ -477,24 +479,24 @@ func generateMessageVectors(t *testing.T) []byte { } vec, err := syntax.Marshal(tv) - assertNotError(t, err, "Error marshaling test vectors") + require.Nil(t, err) return vec } func verifyMessageVectors(t *testing.T, data []byte) { var tv MessageTestVectors _, err := syntax.Unmarshal(data, &tv) - assertNotError(t, err, "Malformed message test vectors") + require.Nil(t, err) for _, tc := range tv.Cases { suite := tc.CipherSuite scheme := tc.SignatureScheme priv, err := suite.hpke().Derive(tv.DHSeed) - assertNotError(t, err, "hpke error") + require.Nil(t, err) pub := priv.PublicKey sigPriv, err := scheme.Derive(tv.SigSeed) - assertNotError(t, err, "sig error") + require.Nil(t, err) sigPub := sigPriv.PublicKey bc := &BasicCredential{ @@ -509,7 +511,7 @@ func verifyMessageVectors(t *testing.T, data []byte) { []Credential{cred, cred, cred, cred}) err = ratchetTree.BlankPath(leafIndex(2), true) - assertNotError(t, err, "rtree blank path") + require.Nil(t, err) dp, _ := ratchetTree.Encap(leafIndex(0), []byte{}, tv.Random) @@ -522,8 +524,8 @@ func verifyMessageVectors(t *testing.T, data []byte) { Signature: Signature{tv.Random}, } cikM, err := syntax.Marshal(cik) - assertNotError(t, err, "cik marshal") - assertByteEquals(t, cikM, tc.ClientInitKey) + require.Nil(t, err) + require.Equal(t, cikM, tc.ClientInitKey) // Welcome gi := &GroupInfo{ @@ -539,7 +541,7 @@ func verifyMessageVectors(t *testing.T, data []byte) { var giWire GroupInfo _, err = syntax.Unmarshal(tc.GroupInfo, &giWire) - assertNotError(t, err, "groupInfo unmarshal") + require.Nil(t, err) groupInfoMatch(t, *gi, giWire) @@ -548,18 +550,18 @@ func verifyMessageVectors(t *testing.T, data []byte) { } kpM, err := syntax.Marshal(kp) - assertNotError(t, err, "key package marshal") - assertByteEquals(t, kpM, tc.KeyPackage) + require.Nil(t, err) + require.Equal(t, kpM, tc.KeyPackage) encPayload, err := suite.hpke().Encrypt(pub, []byte{}, tv.Random) - assertNotError(t, err, "encrypt ekp") + require.Nil(t, err) ekp := EncryptedKeyPackage{ ClientInitKeyHash: tv.Random, EncryptedPackage: encPayload, } var ekpWire EncryptedKeyPackage syntax.Unmarshal(tc.EncryptedKeyPackage, &ekpWire) - assertByteEquals(t, ekp.ClientInitKeyHash, ekpWire.ClientInitKeyHash) + require.Equal(t, ekp.ClientInitKeyHash, ekpWire.ClientInitKeyHash) var welcome Welcome welcome.Version = SupportedVersionMLS10 @@ -569,9 +571,9 @@ func verifyMessageVectors(t *testing.T, data []byte) { var welWire Welcome syntax.Unmarshal(tc.Welcome, &welWire) - assertTrue(t, welcome.CipherSuite == welWire.CipherSuite, "welcome suite") - assertTrue(t, welcome.Version == welWire.Version, "welcome version") - assertByteEquals(t, welcome.EncryptedGroupInfo, welWire.EncryptedGroupInfo) + require.Equal(t, welcome.CipherSuite, welWire.CipherSuite) + require.Equal(t, welcome.Version, welWire.Version) + require.Equal(t, welcome.EncryptedGroupInfo, welWire.EncryptedGroupInfo) // proposals addProposal := &Proposal{ @@ -591,8 +593,8 @@ func verifyMessageVectors(t *testing.T, data []byte) { addHs.Signature = Signature{tv.Random} addM, err := syntax.Marshal(addHs) - assertNotError(t, err, "add HS marshal") - assertByteEquals(t, addM, tc.AddProposal) + require.Nil(t, err) + require.Equal(t, addM, tc.AddProposal) updateProposal := &Proposal{ Update: &UpdateProposal{ @@ -611,8 +613,8 @@ func verifyMessageVectors(t *testing.T, data []byte) { updateHs.Signature = Signature{tv.Random} updateM, err := syntax.Marshal(updateHs) - assertNotError(t, err, "update HS marshal") - assertByteEquals(t, updateM, tc.UpdateProposal) + require.Nil(t, err) + require.Equal(t, updateM, tc.UpdateProposal) removeProposal := &Proposal{ Remove: &RemoveProposal{ @@ -630,8 +632,8 @@ func verifyMessageVectors(t *testing.T, data []byte) { } removeHs.Signature = Signature{tv.Random} remM, err := syntax.Marshal(removeHs) - assertNotError(t, err, "remove HS marshal") - assertByteEquals(t, remM, tc.RemoveProposal) + require.Nil(t, err) + require.Equal(t, remM, tc.RemoveProposal) // commit proposal := []ProposalID{{tv.Random}, {tv.Random}} @@ -644,7 +646,7 @@ func verifyMessageVectors(t *testing.T, data []byte) { } var commitWire Commit _, err = syntax.Unmarshal(tc.Commit, &commitWire) - assertNotError(t, err, "commit marshal") + require.Nil(t, err) commitMatch(t, commit, commitWire) //MlsCiphertext @@ -658,7 +660,7 @@ func verifyMessageVectors(t *testing.T, data []byte) { } ctM, err := syntax.Marshal(ct) - assertNotError(t, err, "MLSCiphertext marshal") - assertByteEquals(t, ctM, tc.MLSCiphertext) + require.Nil(t, err) + require.Equal(t, ctM, tc.MLSCiphertext) } } diff --git a/ratchet-tree_test.go b/ratchet-tree_test.go index bbb7f12..6a6e498 100644 --- a/ratchet-tree_test.go +++ b/ratchet-tree_test.go @@ -4,8 +4,10 @@ import ( "bytes" "crypto/rand" "fmt" - "github.com/bifurcation/mint/syntax" "testing" + + "github.com/bifurcation/mint/syntax" + "github.com/stretchr/testify/require" ) func newTestRatchetTree(t *testing.T, cs CipherSuite, secrets [][]byte, creds []Credential) *RatchetTree { @@ -115,17 +117,17 @@ var ( func TestRatchetTreeOneMember(t *testing.T) { tree := newTestRatchetTree(t, supportedSuites[0], [][]byte{secretA}, []Credential{credA}) - assertTrue(t, tree.size() == 1, "size mismatch") - assertEquals(t, *tree.GetCredential(leafIndex(0)), credA) + require.Equal(t, tree.size(), leafCount(1)) + require.Equal(t, *tree.GetCredential(leafIndex(0)), credA) } func TestRatchetTreeMultipleMembers(t *testing.T) { tree := newTestRatchetTree(t, supportedSuites[0], allSecrets, allCreds) - assertTrue(t, tree.size() == 4, "size mismatch") - assertEquals(t, *tree.GetCredential(leafIndex(0)), credA) - assertEquals(t, *tree.GetCredential(leafIndex(1)), credB) - assertEquals(t, *tree.GetCredential(leafIndex(2)), credC) - assertEquals(t, *tree.GetCredential(leafIndex(3)), credD) + require.Equal(t, tree.size(), leafCount(4)) + require.Equal(t, *tree.GetCredential(leafIndex(0)), credA) + require.Equal(t, *tree.GetCredential(leafIndex(1)), credB) + require.Equal(t, *tree.GetCredential(leafIndex(2)), credC) + require.Equal(t, *tree.GetCredential(leafIndex(3)), credD) } func TestRatchetTreeByExtension(t *testing.T) { @@ -139,9 +141,9 @@ func TestRatchetTreeByExtension(t *testing.T) { tree.AddLeaf(leafIndex(0), &privA.PublicKey, &credA) _, rootA := tree.Encap(leafIndex(0), []byte{}, secretA) - assertByteEquals(t, rootA, secretA) - assertByteEquals(t, tree.RootHash(), hashA) - assertEquals(t, *tree.GetCredential(leafIndex(0)), credA) + require.Equal(t, rootA, secretA) + require.Equal(t, tree.RootHash(), hashA) + require.Equal(t, *tree.GetCredential(leafIndex(0)), credA) // Add B privB, err := cs.hpke().Derive(secretB) @@ -150,13 +152,13 @@ func TestRatchetTreeByExtension(t *testing.T) { } tree.AddLeaf(leafIndex(1), &privB.PublicKey, &credB) _, rootB := tree.Encap(leafIndex(1), []byte{}, secretB) - assertByteEquals(t, rootB, secretAB) - assertByteEquals(t, tree.RootHash(), hashAB) - assertEquals(t, *tree.GetCredential(leafIndex(1)), credB) + require.Equal(t, rootB, secretAB) + require.Equal(t, tree.RootHash(), hashAB) + require.Equal(t, *tree.GetCredential(leafIndex(1)), credB) // direct check directAB := newTestRatchetTree(t, supportedSuites[0], allSecrets[:2], allCreds[:2]) - assertTrue(t, directAB.Equals(tree), "TreeAB mismatch") + require.True(t, directAB.Equals(tree)) // Add C privC, err := cs.hpke().Derive(secretC) @@ -165,13 +167,13 @@ func TestRatchetTreeByExtension(t *testing.T) { } tree.AddLeaf(leafIndex(2), &privC.PublicKey, &credC) _, rootC := tree.Encap(leafIndex(2), []byte{}, secretC) - assertByteEquals(t, rootC, secretABC) - assertEquals(t, *tree.GetCredential(leafIndex(2)), credC) - assertByteEquals(t, tree.RootHash(), hashABC) + require.Equal(t, rootC, secretABC) + require.Equal(t, *tree.GetCredential(leafIndex(2)), credC) + require.Equal(t, tree.RootHash(), hashABC) // direct check directABC := newTestRatchetTree(t, supportedSuites[0], allSecrets[:3], allCreds[:3]) - assertTrue(t, directABC.Equals(tree), "TreeABC mismatch") + require.True(t, directABC.Equals(tree)) // Add D privD, err := cs.hpke().Derive(secretD) @@ -181,25 +183,25 @@ func TestRatchetTreeByExtension(t *testing.T) { tree.AddLeaf(leafIndex(3), &privD.PublicKey, &credD) _, rootD := tree.Encap(leafIndex(3), []byte{}, secretD) - assertByteEquals(t, rootD, secretABCD) - assertByteEquals(t, tree.RootHash(), hashABCD) - assertEquals(t, *tree.GetCredential(leafIndex(0)), credA) - assertEquals(t, *tree.GetCredential(leafIndex(1)), credB) - assertEquals(t, *tree.GetCredential(leafIndex(2)), credC) - assertEquals(t, *tree.GetCredential(leafIndex(3)), credD) + require.Equal(t, rootD, secretABCD) + require.Equal(t, tree.RootHash(), hashABCD) + require.Equal(t, *tree.GetCredential(leafIndex(0)), credA) + require.Equal(t, *tree.GetCredential(leafIndex(1)), credB) + require.Equal(t, *tree.GetCredential(leafIndex(2)), credC) + require.Equal(t, *tree.GetCredential(leafIndex(3)), credD) // direct check directABCD := newTestRatchetTree(t, supportedSuites[0], allSecrets, allCreds) - assertTrue(t, directABCD.Equals(tree), "TreeABCD mismatch") + require.True(t, directABCD.Equals(tree)) } func TestRatchetTreeBySerialization(t *testing.T) { before := newTestRatchetTree(t, supportedSuites[0], allSecrets, allCreds) after := newRatchetTree(supportedSuites[0]) enc, err := before.MarshalTLS() - assertNotError(t, err, "Tree marshal error") + require.Nil(t, err) _, err = after.UnmarshalTLS(enc) - assertTrue(t, before.Equals(after), "Tree mismatch") + require.True(t, before.Equals(after)) } func TestRatchetTreeEncryptDecrypt(t *testing.T) { @@ -237,10 +239,10 @@ func TestRatchetTreeEncryptDecrypt(t *testing.T) { // Verify that all trees are equal and the invariants are satisfied for i, tree := range trees { - assertTrue(t, tree.Equals(trees[0]), fmt.Sprintf("Tree %d differs", i)) - assertEquals(t, int(tree.size()), size) - assertTrue(t, tree.checkCredentials(), "credential check failed") - assertTrue(t, tree.checkInvariant(leafIndex(i*2)), "check invariant failed") + require.True(t, tree.Equals(trees[0]), fmt.Sprintf("Tree %d differs", i)) + require.Equal(t, int(tree.size()), size) + require.True(t, tree.checkCredentials()) + require.True(t, tree.checkInvariant(leafIndex(i*2))) } // verify encrypt/decrypt @@ -253,9 +255,9 @@ func TestRatchetTreeEncryptDecrypt(t *testing.T) { } decryptedSecret, err := dstTree.Decap(leafIndex(i), []byte{}, path) - assertNotError(t, err, "Error in decap()") - assertByteEquals(t, rootSecret, decryptedSecret) - assertTrue(t, srcTree.Equals(dstTree), "Failed update on decap()") + require.Nil(t, err) + require.Equal(t, rootSecret, decryptedSecret) + require.True(t, srcTree.Equals(dstTree)) } } } @@ -269,40 +271,40 @@ func TestRatchetTreeSecrets(t *testing.T) { // Marshal the private and public parts marshaledPub, err := syntax.Marshal(tree) - assertNotError(t, err, "Error in public marshal") + require.Nil(t, err) marshaledPriv, err := syntax.Marshal(secrets) - assertNotError(t, err, "Error in private marshal") + require.Nil(t, err) // Unmarshal the private and public parts tree2 := newRatchetTree(suite) secrets2 := TreeSecrets{} _, err = syntax.Unmarshal(marshaledPub, tree2) - assertNotError(t, err, "Error in public unmarshal") + require.Nil(t, err) _, err = syntax.Unmarshal(marshaledPriv, &secrets2) - assertNotError(t, err, "Error in public unmarshal") + require.Nil(t, err) // Reassemble the tree tree2.SetSecrets(secrets2) // Compare public and private contents - assertDeepEquals(t, tree, tree2) + require.Equal(t, tree, tree2) } func TestRatchetTree_Clone(t *testing.T) { tree := newTestRatchetTree(t, supportedSuites[0], allSecrets, allCreds) - assertTrue(t, tree.size() == 4, "size mismatch") + require.Equal(t, tree.size(), leafCount(4)) cloned := tree.clone() - assertTrue(t, cloned.size() == 4, "size mismatch") - assertEquals(t, *cloned.GetCredential(leafIndex(0)), credA) - assertEquals(t, *cloned.GetCredential(leafIndex(1)), credB) - assertEquals(t, *cloned.GetCredential(leafIndex(2)), credC) - assertEquals(t, *cloned.GetCredential(leafIndex(3)), credD) + require.Equal(t, cloned.size(), leafCount(4)) + require.Equal(t, *cloned.GetCredential(leafIndex(0)), credA) + require.Equal(t, *cloned.GetCredential(leafIndex(1)), credB) + require.Equal(t, *cloned.GetCredential(leafIndex(2)), credC) + require.Equal(t, *cloned.GetCredential(leafIndex(3)), credD) - assertTrue(t, tree.Equals(cloned), "clone is not equaled to its parent") + require.True(t, tree.Equals(cloned)) } /// @@ -376,7 +378,7 @@ func generateRatchetTreeVectors(t *testing.T) []byte { for j := 0; j < leaves; j++ { id := []byte{byte(j)} sigPriv, err := scheme.Derive(id) - assertNotError(t, err, "sig error") + require.Nil(t, err) sigPub := sigPriv.PublicKey bc := &BasicCredential{ Identity: id, @@ -386,9 +388,9 @@ func generateRatchetTreeVectors(t *testing.T) []byte { cred := Credential{Basic: bc} tc.Credentials = append(tc.Credentials, cred) priv, err := suite.hpke().Derive(tv.LeafSecrets[j].Data) - assertNotError(t, err, "hpke error") + require.Nil(t, err) err = tree.AddLeaf(leafIndex(j), &priv.PublicKey, &cred) - assertNotError(t, err, "add leaf") + require.Nil(t, err) tree.Encap(leafIndex(j), []byte{}, tv.LeafSecrets[j].Data) tc.Trees = append(tc.Trees, treeToTreeNode(tree)) } @@ -396,7 +398,7 @@ func generateRatchetTreeVectors(t *testing.T) []byte { // blank out the even numbered leaves for j := 0; j < leaves; j += 2 { err := tree.BlankPath(leafIndex(j), true) - assertNotError(t, err, "blank path") + require.Nil(t, err) tc.Trees = append(tc.Trees, treeToTreeNode(tree)) } @@ -405,27 +407,27 @@ func generateRatchetTreeVectors(t *testing.T) []byte { } vec, err := syntax.Marshal(tv) - assertNotError(t, err, "Error marshaling test vectors") + require.Nil(t, err) return vec } -func assertTreeEq(t *testing.T, tn []TreeNode, tree *RatchetTree) bool { +func requireTreesEqual(t *testing.T, tn []TreeNode, tree *RatchetTree) { nodes := tree.Nodes - assertTrue(t, len(tn) == len(nodes), "nodes size mismatch") + require.Equal(t, len(tn), len(nodes)) for i := 0; i < len(tn); i++ { - assertTrue(t, bytes.Equal(tn[i].Hash, nodes[i].Hash), "hash mismatch") + require.True(t, bytes.Equal(tn[i].Hash, nodes[i].Hash)) if !nodes[i].blank() { - assertTrue(t, bytes.Equal(tn[i].PubKey.Data, nodes[i].Node.PublicKey.Data), "pubkey mismatch") + require.Equal(t, tn[i].PubKey.Data, nodes[i].Node.PublicKey.Data) } else { - assertTrue(t, tn[i].PubKey == nil, "blank node mismatch") + require.Nil(t, tn[i].PubKey) } } - return true } + func verifyRatchetTreeVectors(t *testing.T, data []byte) { var tv RatchetTreeVectors _, err := syntax.Unmarshal(data, &tv) - assertNotError(t, err, "Malformed test vectors") + require.Nil(t, err) for _, tc := range tv.Cases { suite := tc.CipherSuite @@ -433,19 +435,19 @@ func verifyRatchetTreeVectors(t *testing.T, data []byte) { var tci = 0 for i := 0; i < len(tv.LeafSecrets); i++ { priv, err := suite.hpke().Derive(tv.LeafSecrets[i].Data) - assertNotError(t, err, "derive hpke") + require.Nil(t, err) err = tree.AddLeaf(leafIndex(i), &priv.PublicKey, &tc.Credentials[i]) - assertNotError(t, err, "add leaf") + require.Nil(t, err) tree.Encap(leafIndex(i), []byte{}, tv.LeafSecrets[i].Data) - assertTrue(t, assertTreeEq(t, tc.Trees[tci], tree), "tree unequal") + requireTreesEqual(t, tc.Trees[tci], tree) tci += 1 } // blank even numbered leaves for j := 0; j < len(tv.LeafSecrets); j += 2 { err := tree.BlankPath(leafIndex(j), true) - assertNotError(t, err, "blank path") - assertTreeEq(t, tc.Trees[tci], tree) + require.Nil(t, err) + requireTreesEqual(t, tc.Trees[tci], tree) tci += 1 } } diff --git a/state_test.go b/state_test.go index 7a6ebdb..cdd421a 100644 --- a/state_test.go +++ b/state_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/bifurcation/mint/syntax" + "github.com/stretchr/testify/require" ) var ( @@ -35,7 +36,7 @@ func setup(t *testing.T) StateTest { cred := NewBasicCredential(userId, scheme, &sigPriv) //cik gen cik, err := NewClientInitKey(suite, cred) - assertNotError(t, err, "NewClientInitKey error") + require.Nil(t, err) // save all the materials stateTest.identityPrivs = append(stateTest.identityPrivs, sigPriv) stateTest.credentials = append(stateTest.credentials, *cred) @@ -57,18 +58,18 @@ func setupGroup(t *testing.T) StateTest { for i := 1; i < groupSize; i++ { add := states[0].Add(stateTest.clientInitKeys[i]) _, err := states[0].Handle(add) - assertNotError(t, err, "add failed") + require.Nil(t, err) } // commit the adds secret, _ := getRandomBytes(32) _, welcome, next, err := states[0].Commit(secret) - assertNotError(t, err, "commit add proposals failed") + require.Nil(t, err) states[0] = *next // initialize the new joiners from the welcome for i := 1; i < groupSize; i++ { s, err := NewJoinedState([]ClientInitKey{stateTest.clientInitKeys[i]}, *welcome) - assertNotError(t, err, "initializing the state from welcome failed") + require.Nil(t, err) states = append(states, *s) } stateTest.states = states @@ -76,7 +77,7 @@ func setupGroup(t *testing.T) StateTest { // Verify that the states are all equivalent for _, lhs := range stateTest.states { for _, rhs := range stateTest.states { - assertTrue(t, lhs.Equals(rhs), "State mismatch") + require.True(t, lhs.Equals(rhs)) } } @@ -98,26 +99,26 @@ func TestStateTwoPerson(t *testing.T) { // add the second participant add := first0.Add(stateTest.clientInitKeys[1]) _, err := first0.Handle(add) - assertNotError(t, err, "handle add failed") + require.Nil(t, err) // commit adding the second participant secret, _ := getRandomBytes(32) _, welcome, first1, err := first0.Commit(secret) - assertNotError(t, err, "state_test. commit failed") + require.Nil(t, err) // Initialize the second participant from the Welcome second1, err := NewJoinedState([]ClientInitKey{stateTest.clientInitKeys[1]}, *welcome) - assertNotError(t, err, "state_test: state creation using Welcome failed") + require.Nil(t, err) // Verify that the two states are equivalent - assertTrue(t, first1.Equals(*second1), "State mismatch") + require.True(t, first1.Equals(*second1)) /// Verify that they can exchange protected messages ct, err := first1.Protect(testMessage) - assertNotError(t, err, "protect error") + require.Nil(t, err) pt, err := second1.Unprotect(ct) - assertNotError(t, err, "unprotect failure") - assertByteEquals(t, pt, testMessage) + require.Nil(t, err) + require.Equal(t, pt, testMessage) } func TestStateMarshalUnmarshal(t *testing.T) { @@ -127,50 +128,50 @@ func TestStateMarshalUnmarshal(t *testing.T) { add := alice0.Add(stateTest.clientInitKeys[1]) _, err := alice0.Handle(add) - assertNotError(t, err, "Initial add failed") + require.Nil(t, err) secret, _ := getRandomBytes(32) _, welcome1, alice1, err := alice0.Commit(secret) - assertNotError(t, err, "Initial commit failed") + require.Nil(t, err) // Marshal Alice's secret state alice1priv, err := syntax.Marshal(alice1.GetSecrets()) - assertNotError(t, err, "Error marshaling Alice private values") + require.Nil(t, err) // Initialize Bob generate an Update+Commit bob1, err := NewJoinedState([]ClientInitKey{stateTest.clientInitKeys[1]}, *welcome1) - assertNotError(t, err, "state_test: state creation using Welcome failed") - assertTrue(t, alice1.Equals(*bob1), "State mismatch") + require.Nil(t, err) + require.True(t, alice1.Equals(*bob1)) update := bob1.Update(secret) _, err = bob1.Handle(update) - assertNotError(t, err, "Update failed at Bob") + require.Nil(t, err) commit, _, bob2, err := bob1.Commit(secret) - assertNotError(t, err, "Update commit generation failed") + require.Nil(t, err) // Recreate Alice from Welcome and secrets alice1aPriv := StateSecrets{} _, err = syntax.Unmarshal(alice1priv, &alice1aPriv) - assertNotError(t, err, "Error unmarshaling Alice private values") + require.Nil(t, err) alice1a, err := NewStateFromWelcomeAndSecrets(*welcome1, alice1aPriv) - assertNotError(t, err, "Error importing group info from Welcome") + require.Nil(t, err) // Verify that Alice can process Bob's Update+Commit _, err = alice1a.Handle(update) - assertNotError(t, err, "Update failed at Alice") + require.Nil(t, err) alice2, err := alice1a.Handle(commit) - assertNotError(t, err, "Update commit handling failed") + require.Nil(t, err) // Verify that Alice and Bob can exchange protected messages /// Verify that they can exchange protected messages ct, err := alice2.Protect(testMessage) - assertNotError(t, err, "protect error") + require.Nil(t, err) pt, err := bob2.Unprotect(ct) - assertNotError(t, err, "unprotect failure") - assertByteEquals(t, pt, testMessage) + require.Nil(t, err) + require.Equal(t, pt, testMessage) } func TestStateMulti(t *testing.T) { @@ -183,25 +184,25 @@ func TestStateMulti(t *testing.T) { for i := 1; i < groupSize; i++ { add := stateTest.states[0].Add(stateTest.clientInitKeys[i]) _, err := stateTest.states[0].Handle(add) - assertNotError(t, err, "add failed") + require.Nil(t, err) } // commit the adds secret, _ := getRandomBytes(32) _, welcome, next, err := stateTest.states[0].Commit(secret) - assertNotError(t, err, "commit add proposals failed") + require.Nil(t, err) stateTest.states[0] = *next // initialize the new joiners from the welcome for i := 1; i < groupSize; i++ { s, err := NewJoinedState([]ClientInitKey{stateTest.clientInitKeys[i]}, *welcome) - assertNotError(t, err, "initializing the state from welcome failed") + require.Nil(t, err) stateTest.states = append(stateTest.states, *s) } // Verify that the states are all equivalent for _, lhs := range stateTest.states { for _, rhs := range stateTest.states { - assertTrue(t, lhs.Equals(rhs), "State mismatch") + require.True(t, lhs.Equals(rhs)) } } @@ -213,7 +214,7 @@ func TestStateMulti(t *testing.T) { continue } pt, _ := o.Unprotect(ct) - assertByteEquals(t, pt, testMessage) + require.Equal(t, pt, testMessage) } } } @@ -231,7 +232,7 @@ func TestStateCipherNegotiation(t *testing.T) { var aliceCiks []ClientInitKey for _, s := range aliceSuites { cik, err := NewClientInitKey(s, &aliceCred) - assertNotError(t, err, "NewClientInitKey error") + require.Nil(t, err) aliceCiks = append(aliceCiks, *cik) } @@ -247,20 +248,20 @@ func TestStateCipherNegotiation(t *testing.T) { var bobCiks []ClientInitKey for _, s := range bobSuites { cik, err := NewClientInitKey(s, &bobCred) - assertNotError(t, err, "NewClientInitKey error") + require.Nil(t, err) bobCiks = append(bobCiks, *cik) } // Bob should choose P-256 secret, _ := getRandomBytes(32) welcome, bobState, err := negotiateWithPeer(groupId, bobCiks, aliceCiks, secret) - assertNotError(t, err, "state negotiation failed") + require.Nil(t, err) // Alice should also arrive at P-256 aliceState, err := NewJoinedState(aliceCiks, *welcome) - assertNotError(t, err, "state negotiation failed") + require.Nil(t, err) - assertTrue(t, aliceState.Equals(*bobState), "states are unequal") + require.True(t, aliceState.Equals(*bobState)) } func TestStateUpdate(t *testing.T) { @@ -270,23 +271,23 @@ func TestStateUpdate(t *testing.T) { update := state.Update(leafSecret) state.Handle(update) commit, _, next, err := state.Commit(leafSecret) - assertNotError(t, err, "creator commit error") + require.Nil(t, err) for j, other := range stateTest.states { if j == i { stateTest.states[j] = *next } else { _, err := other.Handle(update) - assertNotError(t, err, "Update recipient proposal fail") + require.Nil(t, err) newState, err := other.Handle(commit) - assertNotError(t, err, "Update recipient commit fail") + require.Nil(t, err) stateTest.states[j] = *newState } } for _, s := range stateTest.states { - assertTrue(t, stateTest.states[0].Equals(s), "states unequal") + require.True(t, stateTest.states[0].Equals(s)) } } } @@ -298,7 +299,7 @@ func TestStateRemove(t *testing.T) { stateTest.states[i].Handle(remove) secret, _ := getRandomBytes(32) commit, _, next, err := stateTest.states[i].Commit(secret) - assertNotError(t, err, "remove error") + require.Nil(t, err) stateTest.states = stateTest.states[:len(stateTest.states)-1] for j, state := range stateTest.states { @@ -307,13 +308,13 @@ func TestStateRemove(t *testing.T) { } else { state.Handle(remove) newState, err := state.Handle(commit) - assertNotError(t, err, "remove processing error by others") + require.Nil(t, err) stateTest.states[j] = *newState } } for _, s := range stateTest.states { - assertTrue(t, s.Equals(stateTest.states[0]), "states unequal") + require.True(t, s.Equals(stateTest.states[0])) } } } diff --git a/stream_test.go b/stream_test.go index f063fd6..e5c2254 100644 --- a/stream_test.go +++ b/stream_test.go @@ -2,6 +2,8 @@ package mls import ( "testing" + + "github.com/stretchr/testify/require" ) type streamTestVec struct { @@ -26,25 +28,25 @@ func TestWriteStream(t *testing.T) { w := NewWriteStream() err := w.Write(streamTestInputs.val1) - assertNotError(t, err, "Error writing to stream") + require.Nil(t, err) err = w.Write(streamTestInputs.val2) - assertNotError(t, err, "Error writing to stream") + require.Nil(t, err) err = w.Write(streamTestInputs.val3) - assertNotError(t, err, "Error writing to stream") + require.Nil(t, err) err = w.Write(streamTestInputs.val4) - assertNotError(t, err, "Error writing to stream") + require.Nil(t, err) encoded := w.Data() - assertByteEquals(t, encoded, streamTestInputs.encoded) + require.Equal(t, encoded, streamTestInputs.encoded) w2 := NewWriteStream() err = w2.WriteAll(streamTestInputs.val1, streamTestInputs.val2, streamTestInputs.val3, streamTestInputs.val4) - assertNotError(t, err, "Error in WriteAll") - assertByteEquals(t, w.Data(), w2.Data()) + require.Nil(t, err) + require.Equal(t, w.Data(), w2.Data()) } func TestReadStream(t *testing.T) { @@ -52,27 +54,27 @@ func TestReadStream(t *testing.T) { var val1 uint8 read, err := r.Read(&val1) - assertNotError(t, err, "Error reading from stream") - assertEquals(t, read, 1) - assertDeepEquals(t, val1, streamTestInputs.val1) + require.Nil(t, err) + require.Equal(t, read, 1) + require.Equal(t, val1, streamTestInputs.val1) var val2 uint16 read, err = r.Read(&val2) - assertNotError(t, err, "Error reading from stream") - assertEquals(t, read, 2) - assertDeepEquals(t, val2, streamTestInputs.val2) + require.Nil(t, err) + require.Equal(t, read, 2) + require.Equal(t, val2, streamTestInputs.val2) var val3 streamTestVec read, err = r.Read(&val3) - assertNotError(t, err, "Error reading from stream") - assertEquals(t, read, 5) - assertDeepEquals(t, val3, streamTestInputs.val3) + require.Nil(t, err) + require.Equal(t, read, 5) + require.Equal(t, val3, streamTestInputs.val3) var val4 uint32 read, err = r.Read(&val4) - assertNotError(t, err, "Error reading from stream") - assertEquals(t, read, 4) - assertDeepEquals(t, val4, streamTestInputs.val4) + require.Nil(t, err) + require.Equal(t, read, 4) + require.Equal(t, val4, streamTestInputs.val4) var val1a uint8 var val2a uint16 @@ -80,11 +82,11 @@ func TestReadStream(t *testing.T) { var val4a uint32 r2 := NewReadStream(streamTestInputs.encoded) read, err = r2.ReadAll(&val1a, &val2a, &val3a, &val4a) - assertNotError(t, err, "Error in ReadAll") - assertEquals(t, read, len(streamTestInputs.encoded)) - assertDeepEquals(t, val1, val1a) - assertDeepEquals(t, val2, val2a) - assertDeepEquals(t, val3, val3a) - assertDeepEquals(t, val4, val4a) + require.Nil(t, err) + require.Equal(t, read, len(streamTestInputs.encoded)) + require.Equal(t, val1, val1a) + require.Equal(t, val2, val2a) + require.Equal(t, val3, val3a) + require.Equal(t, val4, val4a) } diff --git a/test-vectors_test.go b/test-vectors_test.go index 6534673..2e24d73 100644 --- a/test-vectors_test.go +++ b/test-vectors_test.go @@ -6,6 +6,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/stretchr/testify/require" ) // To generate or verify test vectors, run `go test` with these environment @@ -80,7 +82,7 @@ func vectorGenerate(c TestVectorCase, testDir string) func(t *testing.T) { if len(testDir) != 0 { file := filepath.Join(testDir, c.Filename) err := ioutil.WriteFile(file, vec, 0644) - assertNotError(t, err, "Error writing test vectors") + require.Nil(t, err) } } } @@ -99,7 +101,7 @@ func vectorVerify(c TestVectorCase, testDir string) func(t *testing.T) { file := filepath.Join(testDir, c.Filename) fmt.Printf("Test File %v\n", file) vec, err := ioutil.ReadFile(file) - assertNotError(t, err, "Error reading test vectors") + require.Nil(t, err) // Verify test vectors c.Verify(t, vec) diff --git a/tree-math_test.go b/tree-math_test.go index c21ccad..abada79 100644 --- a/tree-math_test.go +++ b/tree-math_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/bifurcation/mint/syntax" + "github.com/stretchr/testify/require" ) // Precomputed answers for the tree on eleven elements: @@ -181,14 +182,14 @@ func generateTreeMathVectors(t *testing.T) []byte { } vec, err := syntax.Marshal(tv) - assertNotError(t, err, "Error marshaling test vectors") + require.Nil(t, err) return vec } func verifyTreeMathVectors(t *testing.T, data []byte) { var tv TreeMathTestVectors _, err := syntax.Unmarshal(data, &tv) - assertNotError(t, err, "Malformed tree math test vectors") + require.Nil(t, err) tvLen := int(nodeWidth(tv.NumLeaves)) if len(tv.Root) != int(tv.NumLeaves) || len(tv.Left) != tvLen || @@ -197,13 +198,13 @@ func verifyTreeMathVectors(t *testing.T, data []byte) { } for i := range tv.Root { - assertEquals(t, tv.Root[i], root(leafCount(i+1))) + require.Equal(t, tv.Root[i], root(leafCount(i+1))) } for i := range tv.Left { - assertEquals(t, tv.Left[i], left(nodeIndex(i))) - assertEquals(t, tv.Right[i], right(nodeIndex(i), tv.NumLeaves)) - assertEquals(t, tv.Parent[i], parent(nodeIndex(i), tv.NumLeaves)) - assertEquals(t, tv.Sibling[i], sibling(nodeIndex(i), tv.NumLeaves)) + require.Equal(t, tv.Left[i], left(nodeIndex(i))) + require.Equal(t, tv.Right[i], right(nodeIndex(i), tv.NumLeaves)) + require.Equal(t, tv.Parent[i], parent(nodeIndex(i), tv.NumLeaves)) + require.Equal(t, tv.Sibling[i], sibling(nodeIndex(i), tv.NumLeaves)) } }