diff --git a/mixing/mixclient/client.go b/mixing/mixclient/client.go index ef78d5a2a1..7ef1ba0c7f 100644 --- a/mixing/mixclient/client.go +++ b/mixing/mixclient/client.go @@ -218,9 +218,9 @@ func (s *Session) waitForEpoch(ctx context.Context) (time.Time, error) { } } -// sleep blocks until the next synchronization point, or errors when the context -// is cancelled early. -func (s *Session) sleep(ctx context.Context, until time.Time) error { +// sleepUntil blocks until the next synchronization point, or errors when the +// context is cancelled early. +func (s *Session) sleepUntil(ctx context.Context, until time.Time) error { select { case <-ctx.Done(): return ctx.Err() @@ -487,11 +487,36 @@ func (s *Session) run(ctx context.Context, run uint32, expiry int64, prs []*wire return fmt.Errorf("submit KE: %w", err) } - if err := s.sleep(ctx, s.deadlines.recvKE); err != nil { - return err - } - kes := mp.ReceiveKEs(pairingID) - if len(kes) != len(vk) { + // In run 0, the run can proceed when there are KE messages from each + // peer selected for this session and each KE refers to known PR + // messages. This is done using mp.Receive() as it only finds + // messages matching the session. + // + // If this is not the case, we must find a different session that + // other peers are able to participate in. This must be a subset of + // the original PRs that peers have seen, and should optimize for + // including as many mixed outputs as possible. This is done using + // mp.ReceiveKEs(), which returns all KEs matching a PR pairing, even + // if they began in different sessions. 
+ rcv := new(mixpool.Received) + rcv.Run = run + rcv.Sid = s.sid + rcv.KEs = make([]*wire.MsgMixKE, 0, len(vk)) + rcvCtx, rcvCtxCancel := context.WithDeadlineCause(ctx, + s.deadlines.recvKE, errRunStageTimeout) + err = mp.Receive(rcvCtx, len(vk), rcv) + rcvCtxCancel() + kes := rcv.KEs + if len(kes) != len(vk) || errors.Is(err, errRunStageTimeout) { + if run == 0 { + // Based on the pairing data available, begin a new + // session. + if err := s.sleepUntil(ctx, s.deadlines.recvKE); err != nil { + return err + } + kes = mp.ReceiveKEs(pairingID) + // XXX + } // Blame peers s.logf("received %d KEs for %d peers; rerunning", len(kes), len(vk)) return errRerun @@ -499,25 +524,12 @@ func (s *Session) run(ctx context.Context, run uint32, expiry int64, prs []*wire if err != nil { return err } - - // In run 0, the run can proceed when there are KE messages from each - // peer selected for this session and each KE refers to known PR - // messages. - // - // If this is not the case, we must find a different session that other - // peers are able to participate in. This must be a subset of the - // original PRs that peers have seen, and should optimize for including - // as many mixed outputs as possible. 
- - if len(kes) != len(vk) { - // Blame peers - return errRerun - } sort.Slice(kes, func(i, j int) bool { a := identityIndices[kes[i].Identity] b := identityIndices[kes[j].Identity] return a < b }) + revealed := mixing.RevealedKeys{ ECDHPublicKeys: make([]*secp256k1.PublicKey, 0, len(vk)), Ciphertexts: make([]mixing.PQCiphertext, 0, len(vk)), @@ -559,17 +571,17 @@ func (s *Session) run(ctx context.Context, run uint32, expiry int64, prs []*wire } // Receive all ciphertext messages - rcv := new(mixpool.Received) - rcv.Run = run - rcv.Sid = s.sid + rcv.KEs = nil rcv.CTs = make([]*wire.MsgMixCT, 0, len(vk)) - rcvCtx, rcvCtxCancel := context.WithDeadlineCause(ctx, + rcv.RSs = nil // XXX + rcvCtx, rcvCtxCancel = context.WithDeadlineCause(ctx, s.deadlines.recvCT, errRunStageTimeout) err = mp.Receive(ctx, len(vk), rcv) rcvCtxCancel() cts := rcv.CTs if len(cts) != len(vk) || errors.Is(err, errRunStageTimeout) { // Blame peers + s.logf("received %d CTs for %d peers; rerunning", len(cts), len(vk)) return errRerun } if err != nil { @@ -627,6 +639,7 @@ func (s *Session) run(ctx context.Context, run uint32, expiry int64, prs []*wire srs := rcv.SRs if len(srs) != len(vk) || errors.Is(err, errRunStageTimeout) { // Blame peers + s.logf("received %d SRs for %d peers; rerunning", len(srs), len(vk)) return errRerun } if err != nil { @@ -697,6 +710,7 @@ func (s *Session) run(ctx context.Context, run uint32, expiry int64, prs []*wire dcs := rcv.DCs if len(dcs) != len(vk) || errors.Is(err, errRunStageTimeout) { // Blame peers + s.logf("received %d DCs for %d peers; rerunning", len(dcs), len(vk)) return errRerun } if err != nil { @@ -741,6 +755,7 @@ func (s *Session) run(ctx context.Context, run uint32, expiry int64, prs []*wire var errMM missingMessage if errors.As(err, &errMM) { // Blame peers + s.logf("missing message; rerunning") return errRerun } return err @@ -773,6 +788,7 @@ func (s *Session) run(ctx context.Context, run uint32, expiry int64, prs []*wire cms := 
rcv.CMs if len(cms) != len(vk) || errors.Is(err, errRunStageTimeout) { // Blame peers + s.logf("received %d CMs for %d peers; rerunning", len(cms), len(vk)) return errRerun } if err != nil {