Skip to content

Commit

Permalink
mixing: Create new sessions over incrementing runs
Browse files Browse the repository at this point in the history
This simplifies the code significantly while simultaneously making mixpool
acceptance checks stricter in ruruns after blame due to the session validation
and PR inclusion rules that were only being performed during run 0.

This avoids an issue where peers who have been excluded in a later rerun
continue to submit messages in the rerun.  These messages were not being
correctly rejected by mixpool, and could be improperly received by clients.
Rejecting these messages would have required wallets to provide additional
hints to the mixpool about which peers are still in the rerun, and this
information would be unavailable to dcrd mixpools entirely.

Instead, a new run-0 session is formed with a subset of the original peers,
and all of the existing run-0 validation would continue to be executed for the
rerun.  mixpools no longer understand the difference between reformed sessions
and reruns, and will refuse to accept any non-run-0 message.

In the future, this may also be useful to observe new rerun sessions that
other peers have tried to create that differ from our own, which will be
useful first step in debugging why these sets differ.

This is technically a breaking change that will stop rerun mixes for older
wallets who do not also create the same new sessions, but this only affects
reruns after blame assignment and original run-0 sessions will continue to
operate properly if all peers are honest and behaving.
  • Loading branch information
jrick committed Jun 3, 2024
1 parent fca1a59 commit 184cc8a
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 251 deletions.
7 changes: 3 additions & 4 deletions mixing/mixclient/blame.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (e blamedIdentities) String() string {
}

func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) {
c.logf("Blaming for sid=%x run=%d", sesRun.sid[:], sesRun.run)
c.logf("Blaming for sid=%x", sesRun.sid[:])

mp := c.mixpool
prs := sesRun.prs
Expand Down Expand Up @@ -79,7 +79,6 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) {

// Receive currently-revealed secrets
rcv := new(mixpool.Received)
rcv.Run = sesRun.run
rcv.Sid = sesRun.sid
rcv.RSs = make([]*wire.MsgMixSecrets, 0, len(sesRun.prs))
_ = mp.Receive(ctx, 0, rcv)
Expand Down Expand Up @@ -193,7 +192,7 @@ KELoop:

// Recover or initialize PRNG from seed and the last run that
// caused secrets to be generated.
p.prng = chacha20prng.New(p.rs.Seed[:], sesRun.prngRun)
p.prng = chacha20prng.New(p.rs.Seed[:], 0)

// Recover derived key exchange from PRNG.
p.kx, err = mixing.NewKX(p.prng)
Expand Down Expand Up @@ -320,7 +319,7 @@ SRLoop:
revealed.Ciphertexts = append(revealed.Ciphertexts, ct[p.myVk])
}
sharedSecrets, err := p.kx.SharedSecrets(revealed,
sesRun.sid[:], sesRun.run, sesRun.mcounts)
sesRun.sid[:], 0, sesRun.mcounts)
var decapErr *mixing.DecapsulateError
if errors.As(err, &decapErr) {
submittingID := p.id
Expand Down
71 changes: 24 additions & 47 deletions mixing/mixclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,10 @@ type pairedSessions struct {

type sessionRun struct {
sid [32]byte
run uint32
mtot uint32

// Whether this run must generate fresh KX keys, SR/DC messages.
// prngRun records the run (used as PRNG nonce) for the last run where
// a new PRNG and keys were generated.
freshGen bool
prngRun uint32

deadlines

Expand Down Expand Up @@ -578,7 +574,7 @@ func (c *Client) epochTicker(ctx context.Context) error {
// results in "expired PR" errors.
// Ideally, we would behave like dcrd and only remove sessions that have
// mined mix transactions or are otherwise double spent in a block.
c.mixpool.RemoveConfirmedRuns()
c.mixpool.RemoveConfirmedSessions()
c.expireMessages()

for _, p := range c.pairings {
Expand Down Expand Up @@ -768,7 +764,7 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir
time.Sleep(10 * time.Second)
c.logf("sid=%x removing mixed session completed with transaction %v",
mixedSession.sid[:], mixedSession.cj.txHash)
c.mixpool.RemoveSession(mixedSession.sid, true)
c.mixpool.RemoveSession(mixedSession.sid)
}()
}
if len(unmixedPeers) == 0 {
Expand Down Expand Up @@ -822,10 +818,8 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir

sesRun := sessionRun{
sid: sid,
run: 0,
prs: prs,
freshGen: true,
prngRun: 0,
deadlines: d,
mcounts: make([]uint32, 0, len(prs)),
}
Expand Down Expand Up @@ -865,12 +859,8 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir
}
currentRun.mtot = m

action := "created"
if currentRun.run != 0 {
action = "rerunning"
}
sesLog.logf("%s session for pairid=%x from %d total %d local PRs %s",
action, ps.pairing, len(prHashes), localPeerCount, prHashes)
sesLog.logf("created session for pairid=%x from %d total %d local PRs %s",
ps.pairing, len(prHashes), localPeerCount, prHashes)

if localPeerCount == 0 {
sesLog.logf("no more local peers")
Expand Down Expand Up @@ -903,10 +893,8 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir

rerun = &sessionRun{
sid: altses.sid,
run: 0,
prs: altses.prs,
freshGen: false,
prngRun: 0,
deadlines: d,
}
continue
Expand Down Expand Up @@ -956,7 +944,6 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)

mp := c.wallet.Mixpool()
sesRun := &ps.runs[len(ps.runs)-1]
run := sesRun.run
prs := sesRun.prs

d := &sesRun.deadlines
Expand Down Expand Up @@ -984,7 +971,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
if err != nil {
return err
}
p.prng = chacha20prng.New(p.prngSeed[:], sesRun.prngRun)
p.prng = chacha20prng.New(p.prngSeed[:], 0)

// Generate fresh keys from this run's PRNG
p.kx, err = mixing.NewKX(p.prng)
Expand Down Expand Up @@ -1026,14 +1013,14 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
for i := range p.srMsg {
srMsgBytes[i] = p.srMsg[i].Bytes()
}
rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, run,
rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, 0,
*p.prngSeed, srMsgBytes, p.dcMsg)
c.blake256HasherMu.Lock()
commitment := rs.Commitment(c.blake256Hasher)
c.blake256HasherMu.Unlock()
ecdhPub := *(*[33]byte)(p.kx.ECDHPublicKey.SerializeCompressed())
pqPub := *p.kx.PQPublicKey
ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, run,
ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, 0,
uint32(identityIndices[*p.id]), ecdhPub, pqPub, commitment,
seenPRs)

Expand Down Expand Up @@ -1081,21 +1068,20 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
var kes []*wire.MsgMixKeyExchange
recvKEs := func(sesRun *sessionRun) (kes []*wire.MsgMixKeyExchange, err error) {
rcv := new(mixpool.Received)
rcv.Run = sesRun.run
rcv.Sid = sesRun.sid
rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(sesRun.prs))
ctx, cancel := context.WithDeadline(ctx, d.recvKE)
defer cancel()
err = mp.Receive(ctx, len(sesRun.prs), rcv)
if ctx.Err() != nil {
err = fmt.Errorf("session %x run-%d KE receive context cancelled: %w",
sesRun.sid[:], sesRun.run, ctx.Err())
err = fmt.Errorf("session %x KE receive context cancelled: %w",
sesRun.sid[:], ctx.Err())
}
return rcv.KEs, err
}

switch {
case run == 0:
case !*madePairing:
// Receive KEs for the last attempted session. Local
// peers may have been modified (new keys generated, and myVk
// indexes changed) if this is a recreated session, and we
Expand Down Expand Up @@ -1125,7 +1111,6 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
return c.alternateSession(ps.pairing, sesRun.prs, d)

default:
// Receive KEs only for the agreed-upon session.
kes, err = recvKEs(sesRun)
if err != nil {
return err
Expand All @@ -1139,7 +1124,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}

// Remove paired local peers from waiting pairing.
if run == 0 {
if !*madePairing {
c.mu.Lock()
if waiting := c.pairings[string(ps.pairing)]; waiting != nil {
for id := range ps.localPeers {
Expand All @@ -1151,10 +1136,8 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}
c.mu.Unlock()

if !*madePairing {
c.pairingWG.Done()
*madePairing = true
}
*madePairing = true
c.pairingWG.Done()
}

sort.Slice(kes, func(i, j int) bool {
Expand Down Expand Up @@ -1204,7 +1187,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}

// Send ciphertext messages
ct := wire.NewMsgMixCiphertexts(*p.id, sesRun.sid, run, pqct, seenKEs)
ct := wire.NewMsgMixCiphertexts(*p.id, sesRun.sid, 0, pqct, seenKEs)
p.ct = ct
c.testHook(hookBeforePeerCTPublish, sesRun, p)
return p.signAndSubmit(ct)
Expand All @@ -1213,7 +1196,6 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
// Receive all ciphertext messages
rcv := new(mixpool.Received)
rcv.Sid = sesRun.sid
rcv.Run = run
rcv.KEs = nil
rcv.CTs = make([]*wire.MsgMixCiphertexts, 0, len(prs))
rcvCtx, rcvCtxCancel := context.WithDeadline(ctx, d.recvCT)
Expand Down Expand Up @@ -1265,7 +1247,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}

// Derive shared secret keys
shared, err := p.kx.SharedSecrets(revealed, sesRun.sid[:], run, sesRun.mcounts)
shared, err := p.kx.SharedSecrets(revealed, sesRun.sid[:], 0, sesRun.mcounts)
if err != nil {
p.triggeredBlame = true
return errTriggeredBlame
Expand All @@ -1283,7 +1265,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)

// Broadcast message commitment and exponential DC-mix vectors for slot
// reservations.
sr := wire.NewMsgMixSlotReserve(*p.id, sesRun.sid, run, srMixBytes, seenCTs)
sr := wire.NewMsgMixSlotReserve(*p.id, sesRun.sid, 0, srMixBytes, seenCTs)
p.sr = sr
c.testHook(hookBeforePeerSRPublish, sesRun, p)
return p.signAndSubmit(sr)
Expand Down Expand Up @@ -1365,7 +1347,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}

// Broadcast XOR DC-net vectors.
dc := wire.NewMsgMixDCNet(*p.id, sesRun.sid, run, p.dcNet, seenSRs)
dc := wire.NewMsgMixDCNet(*p.id, sesRun.sid, 0, p.dcNet, seenSRs)
p.dc = dc
c.testHook(hookBeforePeerDCPublish, sesRun, p)
return p.signAndSubmit(dc)
Expand Down Expand Up @@ -1445,7 +1427,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}

// Broadcast partially signed mix tx
cm := wire.NewMsgMixConfirm(*p.id, sesRun.sid, run,
cm := wire.NewMsgMixConfirm(*p.id, sesRun.sid, 0,
p.coinjoin.Tx().Copy(), seenDCs)
p.cm = cm
return p.signAndSubmit(cm)
Expand Down Expand Up @@ -1555,7 +1537,7 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash,
}
err = c.forLocalPeers(ctx, sesRun, func(p *peer) error {
fp := wire.NewMsgMixFactoredPoly(*p.id, sesRun.sid,
sesRun.run, rootBytes, seenSRs)
0, rootBytes, seenSRs)
return p.signAndSubmit(fp)
})
return roots, err
Expand All @@ -1568,7 +1550,6 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash,
expectedMessages := 1
rcv := &mixpool.Received{
Sid: sesRun.sid,
Run: sesRun.run,
FPs: make([]*wire.MsgMixFactoredPoly, 0, sesRun.mtot),
}
roots := make([]*big.Int, 0, len(a)-1)
Expand Down Expand Up @@ -1924,21 +1905,17 @@ func excludeBlamed(prevRun *sessionRun, blamed blamedIdentities, revealedSecrets
d := prevRun.deadlines
d.restart()

unixEpoch := prevRun.epoch.Unix()
sid := mixing.SortPRsForSession(prs, uint64(unixEpoch))

// mtot, peers, mcounts are all recalculated from the prs before
// calling run()
nextRun := &sessionRun{
sid: prevRun.sid,
run: prevRun.run + 1,
sid: sid,
freshGen: revealedSecrets,
prs: prs,
deadlines: d,
}
if revealedSecrets {
nextRun.freshGen = true
nextRun.prngRun = nextRun.run
} else {
nextRun.freshGen = false
nextRun.prngRun = prevRun.prngRun
}
return nextRun
}

Expand Down
27 changes: 20 additions & 7 deletions mixing/mixclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func TestHonest(t *testing.T) {
select {
case <-ctx.Done():
return
case <-time.After(200 * time.Millisecond):
case <-time.After(500 * time.Millisecond):
}
}
}()
Expand Down Expand Up @@ -367,7 +367,7 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) {
select {
case <-ctx.Done():
return
case <-time.After(200 * time.Millisecond):
case <-time.After(1000 * time.Millisecond):
}
}
}()
Expand Down Expand Up @@ -395,7 +395,7 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) {
func TestCTDisruption(t *testing.T) {
var misbehavingID identity
testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) {
if s.run != 0 || p.myVk != 0 {
if p.myVk != 0 {
return
}
if misbehavingID != [33]byte{} {
Expand All @@ -410,16 +410,23 @@ func TestCTDisruption(t *testing.T) {
func TestCTLength(t *testing.T) {
var misbehavingID identity
testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) {
if s.run != 0 || p.myVk != 0 {
if p.myVk != 0 {
return
}
if misbehavingID != [33]byte{} {
return
}
t.Logf("malicious peer %x: sending too few ciphertexts", p.id[:])
misbehavingID = *p.id
p.ct.Ciphertexts = p.ct.Ciphertexts[:len(p.ct.Ciphertexts)-1]
})

misbehavingID = identity{}
testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) {
if s.run != 0 || p.myVk != 0 {
if p.myVk != 0 {
return
}
if misbehavingID != [33]byte{} {
return
}
t.Logf("malicious peer %x: sending too many ciphertexts", p.id[:])
Expand All @@ -431,7 +438,10 @@ func TestCTLength(t *testing.T) {
func TestSRDisruption(t *testing.T) {
var misbehavingID identity
testDisruption(t, &misbehavingID, hookBeforePeerSRPublish, func(c *Client, s *sessionRun, p *peer) {
if s.run != 0 || p.myVk != 0 {
if p.myVk != 0 {
return
}
if misbehavingID != [33]byte{} {
return
}
t.Logf("malicious peer %x: flipping SR bit", p.id[:])
Expand All @@ -443,7 +453,10 @@ func TestSRDisruption(t *testing.T) {
func TestDCDisruption(t *testing.T) {
var misbehavingID identity
testDisruption(t, &misbehavingID, hookBeforePeerDCPublish, func(c *Client, s *sessionRun, p *peer) {
if s.run != 0 || p.myVk != 0 {
if p.myVk != 0 {
return
}
if misbehavingID != [33]byte{} {
return
}
t.Logf("malicious peer %x: flipping DC bit", p.id[:])
Expand Down
Loading

0 comments on commit 184cc8a

Please sign in to comment.