Skip to content

Commit

Permalink
Merge 0c121b4 into 57fc9a2
Browse files Browse the repository at this point in the history
  • Loading branch information
bifurcation committed Mar 25, 2020
2 parents 57fc9a2 + 0c121b4 commit 61714f6
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 50 deletions.
25 changes: 21 additions & 4 deletions messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ type Commit struct {
Updates []ProposalID `tls:"head=2"`
Removes []ProposalID `tls:"head=2"`
Adds []ProposalID `tls:"head=2"`
Ignored []ProposalID `tls:"head=2"`
Path DirectPath
}

Expand All @@ -286,6 +285,24 @@ func (ct ContentType) ValidForTLS() error {
return validateEnum(ct, ContentTypeApplication, ContentTypeProposal, ContentTypeCommit)
}

type SenderType uint8

const (
SenderTypeInvalid SenderType = 0
SenderTypeMember SenderType = 1
SenderTypePreconfigured SenderType = 2
SenderTypeNewMember SenderType = 3
)

func (st SenderType) ValidForTLS() error {
return validateEnum(st, SenderTypeMember, SenderTypePreconfigured, SenderTypeNewMember)
}

type Sender struct {
Type SenderType
Sender uint32
}

type ApplicationData struct {
Data []byte `tls:"head=4"`
}
Expand Down Expand Up @@ -375,7 +392,7 @@ func (c *MLSPlaintextContent) UnmarshalTLS(data []byte) (int, error) {
type MLSPlaintext struct {
GroupID []byte `tls:"head=1"`
Epoch Epoch
Sender leafIndex
Sender Sender
AuthenticatedData []byte `tls:"head=4"`
Content MLSPlaintextContent
Signature Signature
Expand All @@ -391,7 +408,7 @@ func (pt MLSPlaintext) toBeSigned(ctx GroupContext) []byte {
err = s.Write(struct {
GroupID []byte `tls:"head=1"`
Epoch Epoch
Sender leafIndex
Sender Sender
AuthenticatedData []byte `tls:"head=4"`
Content MLSPlaintextContent
}{
Expand Down Expand Up @@ -427,7 +444,7 @@ func (pt MLSPlaintext) commitContent() []byte {
enc, err := syntax.Marshal(struct {
GroupId []byte `tls:"head=1"`
Epoch Epoch
Sender leafIndex
Sender Sender
Commit Commit
ContentType ContentType
}{
Expand Down
20 changes: 9 additions & 11 deletions messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,13 @@ var (
Updates: []ProposalID{{Hash: []byte{0x00, 0x01}}},
Removes: []ProposalID{{Hash: []byte{0x02, 0x03}}},
Adds: []ProposalID{{Hash: []byte{0x04, 0x05}}},
Ignored: []ProposalID{{Hash: []byte{0x06, 0x07}}},
Path: DirectPath{Nodes: nodes},
}

mlsPlaintextIn = &MLSPlaintext{
GroupID: []byte{0x01, 0x02, 0x03, 0x04},
Epoch: 1,
Sender: 4,
Sender: Sender{SenderTypeMember, 4},
AuthenticatedData: []byte{0xAA, 0xBB, 0xcc, 0xdd},
Content: MLSPlaintextContent{
Application: &ApplicationData{
Expand Down Expand Up @@ -245,6 +244,7 @@ type MessageTestCase struct {

type MessageTestVectors struct {
Epoch Epoch
SenderType SenderType
SignerIndex leafIndex
Removed leafIndex
UserId []byte `tls:"head=1"`
Expand Down Expand Up @@ -273,13 +273,13 @@ func commitMatch(t *testing.T, l, r Commit) {
require.Equal(t, l.Adds, r.Adds)
require.Equal(t, l.Removes, r.Removes)
require.Equal(t, l.Updates, r.Updates)
require.Equal(t, l.Ignored, r.Ignored)
}

/// Gen and Verify
func generateMessageVectors(t *testing.T) []byte {
tv := MessageTestVectors{
Epoch: 0xA0A1A2A3,
SenderType: SenderTypeMember,
SignerIndex: leafIndex(0xB0B1B2B3),
Removed: leafIndex(0xC0C1C2C3),
UserId: bytes.Repeat([]byte{0xD1}, 16),
Expand Down Expand Up @@ -387,7 +387,7 @@ func generateMessageVectors(t *testing.T) []byte {
addHs := MLSPlaintext{
GroupID: tv.GroupID,
Epoch: tv.Epoch,
Sender: tv.SignerIndex,
Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)},
Content: MLSPlaintextContent{
Proposal: addProposal,
},
Expand All @@ -406,7 +406,7 @@ func generateMessageVectors(t *testing.T) []byte {
updateHs := MLSPlaintext{
GroupID: tv.GroupID,
Epoch: tv.Epoch,
Sender: tv.SignerIndex,
Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)},
Content: MLSPlaintextContent{
Proposal: updateProposal,
},
Expand All @@ -425,7 +425,7 @@ func generateMessageVectors(t *testing.T) []byte {
removeHs := MLSPlaintext{
GroupID: tv.GroupID,
Epoch: tv.Epoch,
Sender: tv.SignerIndex,
Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)},
Content: MLSPlaintextContent{
Proposal: removeProposal,
},
Expand All @@ -441,7 +441,6 @@ func generateMessageVectors(t *testing.T) []byte {
Updates: proposal,
Removes: proposal,
Adds: proposal,
Ignored: proposal,
Path: *dp,
}

Expand Down Expand Up @@ -585,7 +584,7 @@ func verifyMessageVectors(t *testing.T, data []byte) {
addHs := MLSPlaintext{
GroupID: tv.GroupID,
Epoch: tv.Epoch,
Sender: tv.SignerIndex,
Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)},
Content: MLSPlaintextContent{
Proposal: addProposal,
},
Expand All @@ -605,7 +604,7 @@ func verifyMessageVectors(t *testing.T, data []byte) {
updateHs := MLSPlaintext{
GroupID: tv.GroupID,
Epoch: tv.Epoch,
Sender: tv.SignerIndex,
Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)},
Content: MLSPlaintextContent{
Proposal: updateProposal,
},
Expand All @@ -625,7 +624,7 @@ func verifyMessageVectors(t *testing.T, data []byte) {
removeHs := MLSPlaintext{
GroupID: tv.GroupID,
Epoch: tv.Epoch,
Sender: tv.SignerIndex,
Sender: Sender{tv.SenderType, uint32(tv.SignerIndex)},
Content: MLSPlaintextContent{
Proposal: removeProposal,
},
Expand All @@ -641,7 +640,6 @@ func verifyMessageVectors(t *testing.T, data []byte) {
Updates: proposal,
Removes: proposal,
Adds: proposal,
Ignored: proposal,
Path: *dp,
}
var commitWire Commit
Expand Down
89 changes: 54 additions & 35 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,6 @@ func (s *State) applyUpdateProposal(target leafIndex, update *UpdateProposal) er
return s.Tree.MergePublic(target, &update.LeafKey)
}

func (s *State) applyUpdateSecret(target leafIndex, secret []byte) error {
err := s.Tree.BlankPath(s.Index, false)
if err != nil {
return err
}
return s.Tree.Merge(s.Index, secret)
}

func (s *State) applyProposals(ids []ProposalID, processed map[string]bool) error {
for _, id := range ids {
pt, ok := s.findProposal(id)
Expand All @@ -405,17 +397,21 @@ func (s *State) applyProposals(ids []ProposalID, processed map[string]bool) erro
}

proposal := pt.Content.Proposal
var err error
switch proposal.Type() {
case ProposalTypeAdd:
err = s.applyAddProposal(proposal.Add)
err := s.applyAddProposal(proposal.Add)
if err != nil {
return err
}
case ProposalTypeUpdate:
if pt.Sender != s.Index {
if pt.Sender.Type != SenderTypeMember {
return fmt.Errorf("mls.state: update from non-member")
}

senderIndex := leafIndex(pt.Sender.Sender)
if senderIndex != s.Index {
// apply update from the given member
err := s.applyUpdateProposal(pt.Sender, proposal.Update)
err := s.applyUpdateProposal(senderIndex, proposal.Update)
if err != nil {
return err
}
Expand All @@ -426,12 +422,18 @@ func (s *State) applyProposals(ids []ProposalID, processed map[string]bool) erro
if !ok {
return fmt.Errorf("mls.state: self-update with no cached secret")
}
err = s.applyUpdateSecret(pt.Sender, updateSecret)

err := s.Tree.BlankPath(s.Index, false)
if err != nil {
return err
}

err = s.Tree.Merge(s.Index, updateSecret)
if err != nil {
return err
}
case ProposalTypeRemove:
err = s.applyRemoveProposal(proposal.Remove)
err := s.applyRemoveProposal(proposal.Remove)
if err != nil {
return err
}
Expand Down Expand Up @@ -479,7 +481,7 @@ func (s State) sign(p Proposal) *MLSPlaintext {
pt := &MLSPlaintext{
GroupID: s.GroupID,
Epoch: s.Epoch,
Sender: s.Index,
Sender: Sender{SenderTypeMember, uint32(s.Index)},
Content: MLSPlaintextContent{
Proposal: &p,
},
Expand All @@ -506,7 +508,7 @@ func (s *State) ratchetAndSign(op Commit, updateSecret []byte, prevGrpCtx GroupC
pt := &MLSPlaintext{
GroupID: s.GroupID,
Epoch: s.Epoch,
Sender: s.Index,
Sender: Sender{SenderTypeMember, uint32(s.Index)},
Content: MLSPlaintextContent{
Commit: &CommitData{
Commit: op,
Expand Down Expand Up @@ -556,7 +558,17 @@ func (s *State) Handle(pt *MLSPlaintext) (*State, error) {
return nil, fmt.Errorf("mls.state: epoch mismatch, have %v, got %v", s.Epoch, pt.Epoch)
}

sigPubKey := s.Tree.GetCredential(pt.Sender).PublicKey()
var sigPubKey *SignaturePublicKey
switch pt.Sender.Type {
case SenderTypeMember:
sigPubKey = s.Tree.GetCredential(leafIndex(pt.Sender.Sender)).PublicKey()

default:
// TODO(RLB): Support add sent by new member
// TODO(RLB): Support add/remove signed by preconfigured key
return nil, fmt.Errorf("mls.state: Unsupported sender type")
}

if !pt.verify(s.groupContext(), sigPubKey, s.Scheme) {
return nil, fmt.Errorf("invalid handshake message signature")
}
Expand All @@ -570,9 +582,11 @@ func (s *State) Handle(pt *MLSPlaintext) (*State, error) {

if contentType != ContentTypeCommit {
return nil, fmt.Errorf("mls.state: incorrect content type")
} else if pt.Sender.Type != SenderTypeMember {
return nil, fmt.Errorf("mls.state: commit from non-member")
}

if pt.Sender == s.Index {
if leafIndex(pt.Sender.Sender) == s.Index {
return nil, fmt.Errorf("mls.state: handle own commits with caching")
}

Expand All @@ -597,7 +611,8 @@ func (s *State) Handle(pt *MLSPlaintext) (*State, error) {
return nil, fmt.Errorf("mls.state: failure to create context %v", err)
}

updateSecret, err := next.Tree.Decap(pt.Sender, ctx, &commitData.Commit.Path)
senderIndex := leafIndex(pt.Sender.Sender)
updateSecret, err := next.Tree.Decap(senderIndex, ctx, &commitData.Commit.Path)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -644,6 +659,14 @@ func (s State) verifyConfirmation(confirmation []byte) bool {
return true
}

func applyGuard(nonceIn []byte, reuseGuard [4]byte) []byte {
nonceOut := dup(nonceIn)
for i := range reuseGuard {
nonceOut[i] ^= reuseGuard[i]
}
return nonceOut
}

func (s *State) encrypt(pt *MLSPlaintext) (*MLSCiphertext, error) {
var generation uint32
var keys keyAndNonce
Expand All @@ -656,12 +679,11 @@ func (s *State) encrypt(pt *MLSPlaintext) (*MLSCiphertext, error) {
return nil, fmt.Errorf("mls.state: encrypt unknown content type")
}

var reuseGuard [4]byte
rand.Read(reuseGuard[:])

stream := NewWriteStream()
// skipping error checks since we are trying plain integers
err := stream.Write(s.Index)
if err == nil {
err = stream.Write(generation)
}
err := stream.WriteAll(s.Index, generation, reuseGuard)
if err != nil {
return nil, fmt.Errorf("mls.state: sender data marshal failure %v", err)
}
Expand All @@ -687,7 +709,7 @@ func (s *State) encrypt(pt *MLSPlaintext) (*MLSCiphertext, error) {
aad := contentAAD(s.GroupID, s.Epoch, pt.Content.Type(),
pt.AuthenticatedData, senderDataNonce, sdCt)
aead, _ := s.CipherSuite.newAEAD(keys.Key)
contentCt := aead.Seal(nil, keys.Nonce, content, aad)
contentCt := aead.Seal(nil, applyGuard(keys.Nonce, reuseGuard), content, aad)

// set up MLSCipherText
ct := &MLSCiphertext{
Expand Down Expand Up @@ -723,11 +745,9 @@ func (s *State) decrypt(ct *MLSCiphertext) (*MLSPlaintext, error) {
// parse the senderData
var sender leafIndex
var generation uint32
var reuseGuard [4]byte
stream := NewReadStream(sd)
_, err = stream.Read(&sender)
if err == nil {
_, err = stream.Read(&generation)
}
_, err = stream.ReadAll(&sender, &generation, &reuseGuard)
if err != nil {
return nil, fmt.Errorf("mls.state: senderData unmarshal failure %v", err)
}
Expand Down Expand Up @@ -757,10 +777,8 @@ func (s *State) decrypt(ct *MLSCiphertext) (*MLSPlaintext, error) {

aad := contentAAD(ct.GroupID, ct.Epoch, ContentType(ct.ContentType),
ct.AuthenticatedData, ct.SenderDataNonce, ct.EncryptedSenderData)

aead, _ := s.CipherSuite.newAEAD(keys.Key)

content, err := aead.Open(nil, keys.Nonce, ct.Ciphertext, aad)
content, err := aead.Open(nil, applyGuard(keys.Nonce, reuseGuard), ct.Ciphertext, aad)
if err != nil {
return nil, fmt.Errorf("mls.state: content decryption failure %v", err)
}
Expand All @@ -781,7 +799,7 @@ func (s *State) decrypt(ct *MLSCiphertext) (*MLSPlaintext, error) {
pt := &MLSPlaintext{
GroupID: s.GroupID,
Epoch: s.Epoch,
Sender: sender,
Sender: Sender{SenderTypeMember, uint32(sender)},
AuthenticatedData: ct.AuthenticatedData,
Content: mlsContent,
Signature: signature,
Expand All @@ -793,7 +811,7 @@ func (s *State) Protect(data []byte) (*MLSCiphertext, error) {
pt := &MLSPlaintext{
GroupID: s.GroupID,
Epoch: s.Epoch,
Sender: s.Index,
Sender: Sender{SenderTypeMember, uint32(s.Index)},
Content: MLSPlaintextContent{
Application: &ApplicationData{
Data: data,
Expand All @@ -811,7 +829,8 @@ func (s *State) Unprotect(ct *MLSCiphertext) ([]byte, error) {
return nil, err
}

sigPubKey := s.Tree.GetCredential(pt.Sender).PublicKey()
senderIndex := leafIndex(pt.Sender.Sender)
sigPubKey := s.Tree.GetCredential(senderIndex).PublicKey()
if !pt.verify(s.groupContext(), sigPubKey, s.Scheme) {
return nil, fmt.Errorf("invalid message signature")
}
Expand Down

0 comments on commit 61714f6

Please sign in to comment.