diff --git a/mixing/mixclient/blame.go b/mixing/mixclient/blame.go index f02d37d45..248d633ff 100644 --- a/mixing/mixclient/blame.go +++ b/mixing/mixclient/blame.go @@ -48,7 +48,7 @@ func (e blamedIdentities) String() string { } func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { - c.logf("Blaming for sid=%x run=%d", sesRun.sid[:], sesRun.run) + c.logf("Blaming for sid=%x", sesRun.sid[:]) mp := c.mixpool prs := sesRun.prs @@ -79,7 +79,6 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { // Receive currently-revealed secrets rcv := new(mixpool.Received) - rcv.Run = sesRun.run rcv.Sid = sesRun.sid rcv.RSs = make([]*wire.MsgMixSecrets, 0, len(sesRun.prs)) _ = mp.Receive(ctx, 0, rcv) @@ -193,7 +192,7 @@ KELoop: // Recover or initialize PRNG from seed and the last run that // caused secrets to be generated. - p.prng = chacha20prng.New(p.rs.Seed[:], sesRun.prngRun) + p.prng = chacha20prng.New(p.rs.Seed[:], 0) // Recover derived key exchange from PRNG. p.kx, err = mixing.NewKX(p.prng) @@ -320,7 +319,7 @@ SRLoop: revealed.Ciphertexts = append(revealed.Ciphertexts, ct[p.myVk]) } sharedSecrets, err := p.kx.SharedSecrets(revealed, - sesRun.sid[:], sesRun.run, sesRun.mcounts) + sesRun.sid[:], 0, sesRun.mcounts) var decapErr *mixing.DecapsulateError if errors.As(err, &decapErr) { submittingID := p.id diff --git a/mixing/mixclient/client.go b/mixing/mixclient/client.go index f8a682a5c..733d331fb 100644 --- a/mixing/mixclient/client.go +++ b/mixing/mixclient/client.go @@ -237,14 +237,10 @@ type pairedSessions struct { type sessionRun struct { sid [32]byte - run uint32 mtot uint32 // Whether this run must generate fresh KX keys, SR/DC messages. - // prngRun records the run (used as PRNG nonce) for the last run where - // a new PRNG and keys were generated. freshGen bool - prngRun uint32 deadlines @@ -578,7 +574,7 @@ func (c *Client) epochTicker(ctx context.Context) error { // results in "expired PR" errors. // Ideally, we would behave like dcrd and only remove sessions that have // mined mix transactions or are otherwise double spent in a block. - c.mixpool.RemoveConfirmedRuns() + c.mixpool.RemoveConfirmedSessions() c.expireMessages() for _, p := range c.pairings { @@ -768,7 +764,7 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir time.Sleep(10 * time.Second) c.logf("sid=%x removing mixed session completed with transaction %v", mixedSession.sid[:], mixedSession.cj.txHash) - c.mixpool.RemoveSession(mixedSession.sid, true) + c.mixpool.RemoveSession(mixedSession.sid) }() } if len(unmixedPeers) == 0 { @@ -822,10 +818,8 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir sesRun := sessionRun{ sid: sid, - run: 0, prs: prs, freshGen: true, - prngRun: 0, deadlines: d, mcounts: make([]uint32, 0, len(prs)), } @@ -865,12 +859,8 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir } currentRun.mtot = m - action := "created" - if currentRun.run != 0 { - action = "rerunning" - } - sesLog.logf("%s session for pairid=%x from %d total %d local PRs %s", - action, ps.pairing, len(prHashes), localPeerCount, prHashes) + sesLog.logf("created session for pairid=%x from %d total %d local PRs %s", + ps.pairing, len(prHashes), localPeerCount, prHashes) if localPeerCount == 0 { sesLog.logf("no more local peers") @@ -903,10 +893,8 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir rerun = &sessionRun{ sid: altses.sid, - run: 0, prs: altses.prs, freshGen: false, - prngRun: 0, deadlines: d, } continue @@ -956,7 +944,6 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) mp := c.wallet.Mixpool() sesRun := &ps.runs[len(ps.runs)-1] - run := sesRun.run prs := sesRun.prs d := &sesRun.deadlines @@ -984,7 +971,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) if err != nil { return err } - p.prng = chacha20prng.New(p.prngSeed[:], sesRun.prngRun) + p.prng = chacha20prng.New(p.prngSeed[:], 0) // Generate fresh keys from this run's PRNG p.kx, err = mixing.NewKX(p.prng) @@ -1026,14 +1013,14 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) for i := range p.srMsg { srMsgBytes[i] = p.srMsg[i].Bytes() } - rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, run, + rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, 0, *p.prngSeed, srMsgBytes, p.dcMsg) c.blake256HasherMu.Lock() commitment := rs.Commitment(c.blake256Hasher) c.blake256HasherMu.Unlock() ecdhPub := *(*[33]byte)(p.kx.ECDHPublicKey.SerializeCompressed()) pqPub := *p.kx.PQPublicKey - ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, run, + ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, 0, uint32(identityIndices[*p.id]), ecdhPub, pqPub, commitment, seenPRs) @@ -1081,21 +1068,20 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) var kes []*wire.MsgMixKeyExchange recvKEs := func(sesRun *sessionRun) (kes []*wire.MsgMixKeyExchange, err error) { rcv := new(mixpool.Received) - rcv.Run = sesRun.run rcv.Sid = sesRun.sid rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(sesRun.prs)) ctx, cancel := context.WithDeadline(ctx, d.recvKE) defer cancel() err = mp.Receive(ctx, len(sesRun.prs), rcv) if ctx.Err() != nil { - err = fmt.Errorf("session %x run-%d KE receive context cancelled: %w", - sesRun.sid[:], sesRun.run, ctx.Err()) + err = fmt.Errorf("session %x KE receive context cancelled: %w", + sesRun.sid[:], ctx.Err()) } return rcv.KEs, err } switch { - case run == 0: + case !*madePairing: // Receive KEs for the last attempted session. Local // peers may have been modified (new keys generated, and myVk // indexes changed) if this is a recreated session, and we @@ -1125,7 +1111,6 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) return c.alternateSession(ps.pairing, sesRun.prs, d) default: - // Receive KEs only for the agreed-upon session. kes, err = recvKEs(sesRun) if err != nil { return err @@ -1139,7 +1124,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } // Remove paired local peers from waiting pairing. - if run == 0 { + if !*madePairing { c.mu.Lock() if waiting := c.pairings[string(ps.pairing)]; waiting != nil { for id := range ps.localPeers { @@ -1151,10 +1136,8 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } c.mu.Unlock() - if !*madePairing { - c.pairingWG.Done() - *madePairing = true - } + *madePairing = true + c.pairingWG.Done() } sort.Slice(kes, func(i, j int) bool { @@ -1204,7 +1187,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } // Send ciphertext messages - ct := wire.NewMsgMixCiphertexts(*p.id, sesRun.sid, run, pqct, seenKEs) + ct := wire.NewMsgMixCiphertexts(*p.id, sesRun.sid, 0, pqct, seenKEs) p.ct = ct c.testHook(hookBeforePeerCTPublish, sesRun, p) return p.signAndSubmit(ct) @@ -1213,7 +1196,6 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // Receive all ciphertext messages rcv := new(mixpool.Received) rcv.Sid = sesRun.sid - rcv.Run = run rcv.KEs = nil rcv.CTs = make([]*wire.MsgMixCiphertexts, 0, len(prs)) rcvCtx, rcvCtxCancel := context.WithDeadline(ctx, d.recvCT) @@ -1265,7 +1247,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } // Derive shared secret keys - shared, err := p.kx.SharedSecrets(revealed, sesRun.sid[:], run, sesRun.mcounts) + shared, err := p.kx.SharedSecrets(revealed, sesRun.sid[:], 0, sesRun.mcounts) if err != nil { p.triggeredBlame = true return errTriggeredBlame @@ -1283,7 +1265,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) // Broadcast message commitment and exponential DC-mix vectors for slot // reservations. - sr := wire.NewMsgMixSlotReserve(*p.id, sesRun.sid, run, srMixBytes, seenCTs) + sr := wire.NewMsgMixSlotReserve(*p.id, sesRun.sid, 0, srMixBytes, seenCTs) p.sr = sr c.testHook(hookBeforePeerSRPublish, sesRun, p) return p.signAndSubmit(sr) @@ -1365,7 +1347,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } // Broadcast XOR DC-net vectors. - dc := wire.NewMsgMixDCNet(*p.id, sesRun.sid, run, p.dcNet, seenSRs) + dc := wire.NewMsgMixDCNet(*p.id, sesRun.sid, 0, p.dcNet, seenSRs) p.dc = dc c.testHook(hookBeforePeerDCPublish, sesRun, p) return p.signAndSubmit(dc) @@ -1445,7 +1427,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) } // Broadcast partially signed mix tx - cm := wire.NewMsgMixConfirm(*p.id, sesRun.sid, run, + cm := wire.NewMsgMixConfirm(*p.id, sesRun.sid, 0, p.coinjoin.Tx().Copy(), seenDCs) p.cm = cm return p.signAndSubmit(cm) @@ -1555,7 +1537,7 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, } err = c.forLocalPeers(ctx, sesRun, func(p *peer) error { fp := wire.NewMsgMixFactoredPoly(*p.id, sesRun.sid, - sesRun.run, rootBytes, seenSRs) + 0, rootBytes, seenSRs) return p.signAndSubmit(fp) }) return roots, err @@ -1568,7 +1550,6 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, expectedMessages := 1 rcv := &mixpool.Received{ Sid: sesRun.sid, - Run: sesRun.run, FPs: make([]*wire.MsgMixFactoredPoly, 0, sesRun.mtot), } roots := make([]*big.Int, 0, len(a)-1) @@ -1924,21 +1905,17 @@ func excludeBlamed(prevRun *sessionRun, blamed blamedIdentities, revealedSecrets d := prevRun.deadlines d.restart() + unixEpoch := prevRun.epoch.Unix() + sid := mixing.SortPRsForSession(prs, uint64(unixEpoch)) + // mtot, peers, mcounts are all recalculated from the prs before // calling run() nextRun := &sessionRun{ - sid: prevRun.sid, - run: prevRun.run + 1, + sid: sid, + freshGen: revealedSecrets, prs: prs, deadlines: d, } - if revealedSecrets { - nextRun.freshGen = true - nextRun.prngRun = nextRun.run - } else { - nextRun.freshGen = false - nextRun.prngRun = prevRun.prngRun - } return nextRun } diff --git a/mixing/mixclient/client_test.go b/mixing/mixclient/client_test.go index 00be35bd8..9e827756e 100644 --- a/mixing/mixclient/client_test.go +++ b/mixing/mixclient/client_test.go @@ -238,7 +238,7 @@ func TestHonest(t *testing.T) { select { case <-ctx.Done(): return - case <-time.After(200 * time.Millisecond): + case <-time.After(500 * time.Millisecond): } } }() @@ -367,7 +367,7 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) { select { case <-ctx.Done(): return - case <-time.After(200 * time.Millisecond): + case <-time.After(1000 * time.Millisecond): } } }() @@ -395,7 +395,7 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) { func TestCTDisruption(t *testing.T) { var misbehavingID identity testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) { - if s.run != 0 || p.myVk != 0 { + if p.myVk != 0 { return } if misbehavingID != [33]byte{} { @@ -410,7 +410,10 @@ func TestCTDisruption(t *testing.T) { func TestCTLength(t *testing.T) { var misbehavingID identity testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) { - if s.run != 0 || p.myVk != 0 { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { return } t.Logf("malicious peer %x: sending too few ciphertexts", p.id[:]) @@ -418,8 +421,12 @@ func TestCTLength(t *testing.T) { p.ct.Ciphertexts = p.ct.Ciphertexts[:len(p.ct.Ciphertexts)-1] }) + misbehavingID = identity{} testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) { - if s.run != 0 || p.myVk != 0 { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { return } t.Logf("malicious peer %x: sending too many ciphertexts", p.id[:]) @@ -431,7 +438,10 @@ func TestCTLength(t *testing.T) { func TestSRDisruption(t *testing.T) { var misbehavingID identity testDisruption(t, &misbehavingID, hookBeforePeerSRPublish, func(c *Client, s *sessionRun, p *peer) { - if s.run != 0 || p.myVk != 0 { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { return } t.Logf("malicious peer %x: flipping SR bit", p.id[:]) @@ -443,7 +453,10 @@ func TestSRDisruption(t *testing.T) { func TestDCDisruption(t *testing.T) { var misbehavingID identity testDisruption(t, &misbehavingID, hookBeforePeerDCPublish, func(c *Client, s *sessionRun, p *peer) { - if s.run != 0 || p.myVk != 0 { + if p.myVk != 0 { + return + } + if misbehavingID != [33]byte{} { return } t.Logf("malicious peer %x: flipping DC bit", p.id[:]) diff --git a/mixing/mixpool/mixpool.go b/mixing/mixpool/mixpool.go index 8a1e1187d..ab6d18dce 100644 --- a/mixing/mixpool/mixpool.go +++ b/mixing/mixpool/mixpool.go @@ -84,23 +84,19 @@ type orphan struct { type session struct { sid [32]byte - runs []runstate - expiry uint32 - bc broadcast -} - -type runstate struct { prs []chainhash.Hash counts [nmsgtypes]uint32 hashes map[chainhash.Hash]struct{} + expiry uint32 + bc broadcast } -func (r *runstate) countFor(t msgtype) uint32 { - return r.counts[t-1] +func (s *session) countFor(t msgtype) uint32 { + return s.counts[t-1] } -func (r *runstate) incrementCountFor(t msgtype) { - r.counts[t-1]++ +func (s *session) incrementCountFor(t msgtype) { + s.counts[t-1]++ } type broadcast struct { @@ -248,7 +244,7 @@ func (p *Pool) MixPRs() []*wire.MsgMixPairReq { p.mtx.Lock() defer p.mtx.Unlock() - p.removeConfirmedRuns() + p.removeConfirmedSessions() res := make([]*wire.MsgMixPairReq, 0, len(p.prs)) for _, pr := range p.prs { @@ -373,10 +369,8 @@ func (p *Pool) expireMessagesNow(height uint32) { } delete(p.sessions, sid) - for _, r := range ses.runs { - for hash := range r.hashes { - delete(p.pool, hash) - } + for hash := range ses.hashes { + delete(p.pool, hash) } } @@ -421,13 +415,13 @@ func (p *Pool) RemoveMessage(msg mixing.Message) { } } -// RemoveSession removes all non-PR messages from a completed or errored -// session. PR messages of a successful run (or rerun) are also removed. -func (p *Pool) RemoveSession(sid [32]byte, success bool) { +// RemoveSession removes the PRs and all session messages involving them from +// a completed session. PR messages of a successful session are also removed. +func (p *Pool) RemoveSession(sid [32]byte) { p.mtx.Lock() defer p.mtx.Unlock() - p.removeSession(sid, nil, success) + p.removeSession(sid, nil, true) } func (p *Pool) removeSession(sid [32]byte, txHash *chainhash.Hash, success bool) { @@ -438,17 +432,15 @@ func (p *Pool) removeSession(sid [32]byte, txHash *chainhash.Hash, success bool) // Delete PRs used to form final run var removePRs []chainhash.Hash - var lastRun *runstate if success { - lastRun = &ses.runs[len(ses.runs)-1] - removePRs = lastRun.prs + removePRs = ses.prs } if txHash != nil || success { if txHash == nil { // XXX: may be better to store this in the runstate as // a CM is received. - for h := range lastRun.hashes { + for h := range ses.hashes { if e, ok := p.pool[h]; ok && e.msgtype == msgtypeCM { cm := e.msg.(*wire.MsgMixConfirm) hash := cm.Mix.TxHash() @@ -463,14 +455,12 @@ func (p *Pool) removeSession(sid [32]byte, txHash *chainhash.Hash, success bool) } delete(p.sessions, sid) - for i, r := range ses.runs { - for hash := range r.hashes { - e, ok := p.pool[hash] - if ok { - log.Debugf("Removing session %x run %d %T %v by %x", - sid[:], i, e.msg, hash, e.msg.Pub()) - delete(p.pool, hash) - } + for hash := range ses.hashes { + e, ok := p.pool[hash] + if ok { + log.Debugf("Removing session %x %T %v by %x", + sid[:], e.msg, hash, e.msg.Pub()) + delete(p.pool, hash) } } @@ -482,31 +472,28 @@ func (p *Pool) removeSession(sid [32]byte, txHash *chainhash.Hash, success bool) } } -// RemoveConfirmedRuns removes all messages including pair requests from +// RemoveConfirmedSessions removes all messages including pair requests from // runs which ended in each peer sending a confirm mix message. -func (p *Pool) RemoveConfirmedRuns() { +func (p *Pool) RemoveConfirmedSessions() { p.mtx.Lock() defer p.mtx.Unlock() - p.removeConfirmedRuns() + p.removeConfirmedSessions() } -func (p *Pool) removeConfirmedRuns() { +func (p *Pool) removeConfirmedSessions() { for sid, ses := range p.sessions { - lastRun := &ses.runs[len(ses.runs)-1] - cmCount := lastRun.countFor(msgtypeCM) - if uint32(len(lastRun.prs)) != cmCount { + cmCount := ses.countFor(msgtypeCM) + if uint32(len(ses.prs)) != cmCount { continue } delete(p.sessions, sid) - for _, run := range ses.runs { - for hash := range run.hashes { - delete(p.pool, hash) - } + for hash := range ses.hashes { + delete(p.pool, hash) } - for _, hash := range lastRun.prs { + for _, hash := range ses.prs { delete(p.pool, hash) pr := p.prs[hash] if pr != nil { @@ -612,7 +599,6 @@ PRLoop: // appended to. Received messages are unsorted. type Received struct { Sid [32]byte - Run uint32 KEs []*wire.MsgMixKeyExchange CTs []*wire.MsgMixCiphertexts SRs []*wire.MsgMixSlotReserve @@ -636,9 +622,7 @@ type Received struct { // the secrets after each peer publishes their own revealed secrets. func (p *Pool) Receive(ctx context.Context, expectedMessages int, r *Received) error { sid := r.Sid - run := r.Run var bc *broadcast - var rs *runstate p.mtx.RLock() ses, ok := p.sessions[sid] @@ -647,11 +631,6 @@ func (p *Pool) Receive(ctx context.Context, expectedMessages int, r *Received) e return fmt.Errorf("unknown session %x", sid[:]) } bc = &ses.bc - if run >= uint32(len(ses.runs)) { - p.mtx.RUnlock() - return fmt.Errorf("unknown run %d", run) - } - rs = &ses.runs[run] nonNilSlices := 0 if r.KEs != nil { @@ -684,7 +663,7 @@ Loop: // Pool is locked for reads. Count if the total number of // expected messages have been received. received := 0 - for hash := range rs.hashes { + for hash := range ses.hashes { msgtype := p.pool[hash].msgtype switch { case msgtype == msgtypeKE && r.KEs != nil: @@ -736,7 +715,7 @@ Loop: } // Pool is locked for reads. Collect all of the messages. - for hash := range rs.hashes { + for hash := range ses.hashes { msg := p.pool[hash].msg switch msg := msg.(type) { case *wire.MsgMixKeyExchange: @@ -790,7 +769,7 @@ var zeroHash chainhash.Hash func (p *Pool) AcceptMessage(msg mixing.Message) (accepted []mixing.Message, err error) { defer func() { if err == nil && len(accepted) == 0 { - // Duplicate message; don't log it again. + // Don't log duplicate messages or non-KE orphans. return } if log.Level() > slog.LevelDebug { @@ -803,8 +782,8 @@ func (p *Pool) AcceptMessage(msg mixing.Message) (accepted []mixing.Message, err log.Debugf("Rejected message %T %v by %x: %v", msg, hash, msg.Pub(), err) default: - log.Debugf("Rejected message %T %v (session %x run %d) by %x: %v", - msg, hash, msg.Sid(), msg.GetRun(), msg.Pub(), err) + log.Debugf("Rejected message %T %v (session %x) by %x: %v", + msg, hash, msg.Sid(), msg.Pub(), err) } return } @@ -814,12 +793,16 @@ func (p *Pool) AcceptMessage(msg mixing.Message) (accepted []mixing.Message, err case *wire.MsgMixPairReq: log.Debugf("Accepted message %T %v by %x", msg, hash, msg.Pub()) default: - log.Debugf("Accepted message %T %v (session %x run %d) by %x", - msg, hash, msg.Sid(), msg.GetRun(), msg.Pub()) + log.Debugf("Accepted message %T %v (session %x) by %x", + msg, hash, msg.Sid(), msg.Pub()) } } }() + if msg.GetRun() != 0 { + return nil, ruleError(fmt.Errorf("nonzero reruns are unsupported")) + } + hash := msg.Hash() if hash == zeroHash { return nil, fmt.Errorf("message of type %T has not been hashed", msg) @@ -938,7 +921,7 @@ func (p *Pool) AcceptMessage(msg mixing.Message) (accepted []mixing.Message, err haveKE = true } } - // Save as an orphan if their KE is not (yet) known. + // Save as an orphan if their KE is not (yet) accepted. if !haveKE { orphansByID := p.orphansByID[*id] if _, ok := orphansByID[hash]; ok { @@ -1373,14 +1356,8 @@ func (p *Pool) acceptKE(ke *wire.MsgMixKeyExchange, hash *chainhash.Hash, id *id sid := ke.SessionID ses := p.sessions[sid] - // Create a session for the first run-0 KE + // Create a session for the first KE if ses == nil { - if ke.Run != 0 { - err := fmt.Errorf("unknown session for run-%d KE", - ke.Run) - return nil, err - } - expiry := ^uint32(0) for i := range prs { prExpiry := prs[i].Expires() @@ -1390,8 +1367,9 @@ func (p *Pool) acceptKE(ke *wire.MsgMixKeyExchange, hash *chainhash.Hash, id *id } ses = &session{ sid: sid, - runs: make([]runstate, 0, 4), + prs: ke.SeenPRs, expiry: expiry, + hashes: make(map[chainhash.Hash]struct{}), bc: broadcast{ch: make(chan struct{})}, } p.sessions[sid] = ses @@ -1408,25 +1386,7 @@ func (p *Pool) acceptKE(ke *wire.MsgMixKeyExchange, hash *chainhash.Hash, id *id func (p *Pool) acceptEntry(msg mixing.Message, msgtype msgtype, hash *chainhash.Hash, id *[33]byte, ses *session) error { - run := msg.GetRun() - if run > uint32(len(ses.runs)) { - return ruleError(fmt.Errorf("message skips runs")) - } - - var rs *runstate - if msgtype == msgtypeKE && run == uint32(len(ses.runs)) { - // Add a runstate for the next run. - ses.runs = append(ses.runs, runstate{ - prs: msg.PrevMsgs(), - hashes: make(map[chainhash.Hash]struct{}), - }) - rs = &ses.runs[len(ses.runs)-1] - } else { - // Add to existing runstate - rs = &ses.runs[run] - } - - rs.hashes[*hash] = struct{}{} + ses.hashes[*hash] = struct{}{} e := entry{ hash: *hash, sid: ses.sid, @@ -1441,7 +1401,7 @@ func (p *Pool) acceptEntry(msg mixing.Message, msgtype msgtype, hash *chainhash. p.sessionsByTxHash[cm.Mix.TxHash()] = ses } - rs.incrementCountFor(msgtype) + ses.incrementCountFor(msgtype) ses.bc.signal() return nil diff --git a/mixing/mixpool/orphans_test.go b/mixing/mixpool/orphans_test.go index 4374349bd..fb6bc0151 100644 --- a/mixing/mixpool/orphans_test.go +++ b/mixing/mixpool/orphans_test.go @@ -67,7 +67,7 @@ func TestOrphans(t *testing.T) { prs := []*wire.MsgMixPairReq{pr} epoch := uint64(1704067200) sid := mixing.SortPRsForSession(prs, epoch) - ke1 := &wire.MsgMixKeyExchange{ + ke := &wire.MsgMixKeyExchange{ Identity: id, SeenPRs: []chainhash.Hash{ pr.Hash(), @@ -76,59 +76,30 @@ func TestOrphans(t *testing.T) { Epoch: epoch, Run: 0, } - err = mixing.SignMessage(ke1, priv) + err = mixing.SignMessage(ke, priv) if err != nil { t.Fatal(err) } - ke1.WriteHash(h) + ke.WriteHash(h) - ke2 := &wire.MsgMixKeyExchange{ - Identity: id, - SeenPRs: []chainhash.Hash{ - pr.Hash(), - }, - SessionID: sid, - Epoch: epoch, - Run: 1, - } - err = mixing.SignMessage(ke2, priv) - if err != nil { - t.Fatal(err) - } - ke2.WriteHash(h) - - fp1 := &wire.MsgMixFactoredPoly{ + fp := &wire.MsgMixFactoredPoly{ Identity: id, SessionID: sid, Run: 0, } - err = mixing.SignMessage(fp1, priv) + err = mixing.SignMessage(fp, priv) if err != nil { t.Fatal(err) } - fp1.WriteHash(h) - - fp2 := &wire.MsgMixFactoredPoly{ - Identity: id, - SessionID: sid, - Run: 1, - } - err = mixing.SignMessage(fp2, priv) - if err != nil { - t.Fatal(err) - } - fp2.WriteHash(h) + fp.WriteHash(h) t.Logf("pr %s", pr.Hash()) - t.Logf("ke1 %s", ke1.Hash()) - t.Logf("ke2 %s", ke2.Hash()) - t.Logf("fp1 %s", fp1.Hash()) - t.Logf("fp2 %s", fp2.Hash()) + t.Logf("ke %s", ke.Hash()) + t.Logf("fp %s", fp.Hash()) - // Create a pair request, several KEs, and later messages belong to - // the session and run increment for each KE. Test different - // combinations of acceptance order to test orphan processing of - // various message types. + // Create a pair request, KE, and later messages that belong to the + // session for each KE. Test different combinations of acceptance + // order to test orphan processing of various message types. type accept struct { desc string message mixing.Message @@ -140,7 +111,7 @@ func TestOrphans(t *testing.T) { // Accept KE, then PR, then FP 0: {{ desc: "accept KE before PR", - message: ke1, + message: ke, errors: true, errAs: new(MissingOwnPRError), accepted: nil, @@ -148,23 +119,23 @@ func TestOrphans(t *testing.T) { desc: "accept PR after KE; both should now process", message: pr, errors: false, - accepted: []mixing.Message{pr, ke1}, + accepted: []mixing.Message{pr, ke}, }, { desc: "accept future message in accepted KE session/run", - message: fp1, + message: fp, errors: false, // maybe later. - accepted: []mixing.Message{fp1}, + accepted: []mixing.Message{fp}, }}, // Accept FP, then KE, then PR 1: {{ desc: "accept FP first", - message: fp1, + message: fp, errors: false, accepted: nil, }, { desc: "accept KE", - message: ke1, + message: ke, errors: true, errAs: new(MissingOwnPRError), accepted: nil, @@ -172,64 +143,7 @@ func TestOrphans(t *testing.T) { desc: "accept PR; all should now be processed", message: pr, errors: false, - accepted: []mixing.Message{pr, ke1, fp1}, - }}, - - // Accept PR, then FP1, then FP2, then KE1, then KE2. - 2: {{ - desc: "accept PR first", - message: pr, - errors: false, - accepted: []mixing.Message{pr}, - }, { - desc: "accept FP1", - message: fp1, - errors: false, - accepted: nil, - }, { - desc: "accept FP2", - message: fp2, - errors: false, - accepted: nil, - }, { - desc: "accept KE1", - message: ke1, - errors: false, - accepted: []mixing.Message{ke1, fp1}, - }, { - desc: "accept KE2", - message: ke2, - errors: false, - accepted: []mixing.Message{ke2, fp2}, - }}, - - 3: {{ - desc: "accept FP1", - message: fp1, - errors: false, - accepted: nil, - }, { - desc: "accept FP2", - message: fp2, - errors: false, - accepted: nil, - }, { - desc: "accept KE1", - message: ke1, - errors: true, - errAs: new(MissingOwnPRError), - accepted: nil, - }, { - desc: "accept KE2", - message: ke2, - errors: true, - errAs: new(MissingOwnPRError), - accepted: nil, - }, { - desc: "accept PR last", - message: pr, - errors: false, - accepted: []mixing.Message{pr, ke1, ke2, fp1, fp2}, + accepted: []mixing.Message{pr, ke, fp}, }}, } @@ -249,8 +163,6 @@ func TestOrphans(t *testing.T) { t.Logf("orphans: %v", mp.orphans) t.Logf("orphansByID: %v", mp.orphansByID) t.Logf("pr: %v", pr) - t.Logf("ke2: %v", ke2) - t.Logf("fp2: %v", fp2) t.Errorf("test %d call %d %q: accepted lengths differ %d != %d", i, j, a.desc, len(accepted), len(a.accepted)) }