Skip to content

Commit

Permalink
stop incrementing runs, form new sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
jrick committed May 31, 2024
1 parent f7faa1d commit db25943
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 253 deletions.
4 changes: 2 additions & 2 deletions internal/rpcserver/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,9 +636,9 @@ type MixPooler interface {
// Message searches the mixing pool for a message by its hash.
Message(query *chainhash.Hash) (mixing.Message, error)

// RemoveConfirmedRuns removes all messages including pair requests
// RemoveConfirmedSessions removes all messages including pair requests
// from runs which ended in each peer sending a confirm mix message.
RemoveConfirmedRuns()
RemoveConfirmedSessions()
}

// TxIndexer provides an interface for retrieving details for a given
Expand Down
2 changes: 1 addition & 1 deletion internal/rpcserver/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2607,7 +2607,7 @@ func handleGetMixMessage(_ context.Context, s *Server, cmd interface{}) (interfa
func handleGetMixPairRequests(_ context.Context, s *Server, _ interface{}) (interface{}, error) {
mp := s.cfg.MixPooler

mp.RemoveConfirmedRuns() // XXX: a bit hacky to do this here
mp.RemoveConfirmedSessions() // XXX: a bit hacky to do this here
prs := mp.MixPRs(nil)

buf := new(strings.Builder)
Expand Down
7 changes: 3 additions & 4 deletions mixing/mixclient/blame.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (e blamedIdentities) String() string {
}

func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) {
c.logf("Blaming for sid=%x run=%d", sesRun.sid[:], sesRun.run)
c.logf("Blaming for sid=%x", sesRun.sid[:])

mp := c.mixpool
prs := sesRun.prs
Expand Down Expand Up @@ -79,7 +79,6 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) {

// Receive currently-revealed secrets
rcv := new(mixpool.Received)
rcv.Run = sesRun.run
rcv.Sid = sesRun.sid
rcv.RSs = make([]*wire.MsgMixSecrets, 0, len(sesRun.prs))
_ = mp.Receive(ctx, 0, rcv)
Expand Down Expand Up @@ -193,7 +192,7 @@ KELoop:

// Recover or initialize PRNG from seed and the last run that
// caused secrets to be generated.
p.prng = chacha20prng.New(p.rs.Seed[:], sesRun.prngRun)
p.prng = chacha20prng.New(p.rs.Seed[:], 0)

// Recover derived key exchange from PRNG.
p.kx, err = mixing.NewKX(p.prng)
Expand Down Expand Up @@ -320,7 +319,7 @@ SRLoop:
revealed.Ciphertexts = append(revealed.Ciphertexts, ct[p.myVk])
}
sharedSecrets, err := p.kx.SharedSecrets(revealed,
sesRun.sid[:], sesRun.run, sesRun.mcounts)
sesRun.sid[:], 0, sesRun.mcounts)
var decapErr *mixing.DecapsulateError
if errors.As(err, &decapErr) {
submittingID := p.id
Expand Down
71 changes: 24 additions & 47 deletions mixing/mixclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,10 @@ type pairedSessions struct {

type sessionRun struct {
sid [32]byte
run uint32
mtot uint32

// Whether this run must generate fresh KX keys, SR/DC messages.
// prngRun records the run (used as PRNG nonce) for the last run where
// a new PRNG and keys were generated.
freshGen bool
prngRun uint32

deadlines

Expand Down Expand Up @@ -578,7 +574,7 @@ func (c *Client) epochTicker(ctx context.Context) error {
// results in "expired PR" errors.
// Ideally, we would behave like dcrd and only remove sessions that have
// mined mix transactions or are otherwise double spent in a block.
c.mixpool.RemoveConfirmedRuns()
c.mixpool.RemoveConfirmedSessions()
c.expireMessages()

for _, p := range c.pairings {
Expand Down Expand Up @@ -768,7 +764,7 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir
time.Sleep(10 * time.Second)
c.logf("sid=%x removing mixed session completed with transaction %v",
mixedSession.sid[:], mixedSession.cj.txHash)
c.mixpool.RemoveSession(mixedSession.sid, true)
c.mixpool.RemoveSession(mixedSession.sid)
}()
}
if len(unmixedPeers) == 0 {
Expand Down Expand Up @@ -822,10 +818,8 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir

sesRun := sessionRun{
sid: sid,
run: 0,
prs: prs,
freshGen: true,
prngRun: 0,
deadlines: d,
mcounts: make([]uint32, 0, len(prs)),
}
Expand Down Expand Up @@ -865,12 +859,8 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir
}
currentRun.mtot = m

action := "created"
if currentRun.run != 0 {
action = "rerunning"
}
sesLog.logf("%s session for pairid=%x from %d total %d local PRs %s",
action, ps.pairing, len(prHashes), localPeerCount, prHashes)
sesLog.logf("created session for pairid=%x from %d total %d local PRs %s",
ps.pairing, len(prHashes), localPeerCount, prHashes)

if localPeerCount == 0 {
sesLog.logf("no more local peers")
Expand Down Expand Up @@ -903,10 +893,8 @@ func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wir

rerun = &sessionRun{
sid: altses.sid,
run: 0,
prs: altses.prs,
freshGen: false,
prngRun: 0,
deadlines: d,
}
continue
Expand Down Expand Up @@ -956,7 +944,6 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)

mp := c.wallet.Mixpool()
sesRun := &ps.runs[len(ps.runs)-1]
run := sesRun.run
prs := sesRun.prs

d := &sesRun.deadlines
Expand Down Expand Up @@ -984,7 +971,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
if err != nil {
return err
}
p.prng = chacha20prng.New(p.prngSeed[:], sesRun.prngRun)
p.prng = chacha20prng.New(p.prngSeed[:], 0)

// Generate fresh keys from this run's PRNG
p.kx, err = mixing.NewKX(p.prng)
Expand Down Expand Up @@ -1026,14 +1013,14 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
for i := range p.srMsg {
srMsgBytes[i] = p.srMsg[i].Bytes()
}
rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, run,
rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, 0,
*p.prngSeed, srMsgBytes, p.dcMsg)
c.blake256HasherMu.Lock()
commitment := rs.Commitment(c.blake256Hasher)
c.blake256HasherMu.Unlock()
ecdhPub := *(*[33]byte)(p.kx.ECDHPublicKey.SerializeCompressed())
pqPub := *p.kx.PQPublicKey
ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, run,
ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, 0,
uint32(identityIndices[*p.id]), ecdhPub, pqPub, commitment,
seenPRs)

Expand Down Expand Up @@ -1081,21 +1068,20 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
var kes []*wire.MsgMixKeyExchange
recvKEs := func(sesRun *sessionRun) (kes []*wire.MsgMixKeyExchange, err error) {
rcv := new(mixpool.Received)
rcv.Run = sesRun.run
rcv.Sid = sesRun.sid
rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(sesRun.prs))
ctx, cancel := context.WithDeadline(ctx, d.recvKE)
defer cancel()
err = mp.Receive(ctx, len(sesRun.prs), rcv)
if ctx.Err() != nil {
err = fmt.Errorf("session %x run-%d KE receive context cancelled: %w",
sesRun.sid[:], sesRun.run, ctx.Err())
err = fmt.Errorf("session %x KE receive context cancelled: %w",
sesRun.sid[:], ctx.Err())
}
return rcv.KEs, err
}

switch {
case run == 0:
case !*madePairing:
// Receive KEs for the last attempted session. Local
// peers may have been modified (new keys generated, and myVk
// indexes changed) if this is a recreated session, and we
Expand Down Expand Up @@ -1125,7 +1111,6 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
return c.alternateSession(ps.pairing, sesRun.prs, d)

default:
// Receive KEs only for the agreed-upon session.
kes, err = recvKEs(sesRun)
if err != nil {
return err
Expand All @@ -1139,7 +1124,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}

// Remove paired local peers from waiting pairing.
if run == 0 {
if !*madePairing {
c.mu.Lock()
if waiting := c.pairings[string(ps.pairing)]; waiting != nil {
for id := range ps.localPeers {
Expand All @@ -1151,10 +1136,8 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}
c.mu.Unlock()

if !*madePairing {
c.pairingWG.Done()
*madePairing = true
}
*madePairing = true
c.pairingWG.Done()
}

sort.Slice(kes, func(i, j int) bool {
Expand Down Expand Up @@ -1204,7 +1187,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}

// Send ciphertext messages
ct := wire.NewMsgMixCiphertexts(*p.id, sesRun.sid, run, pqct, seenKEs)
ct := wire.NewMsgMixCiphertexts(*p.id, sesRun.sid, 0, pqct, seenKEs)
p.ct = ct
c.testHook(hookBeforePeerCTPublish, sesRun, p)
return p.signAndSubmit(ct)
Expand All @@ -1213,7 +1196,6 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
// Receive all ciphertext messages
rcv := new(mixpool.Received)
rcv.Sid = sesRun.sid
rcv.Run = run
rcv.KEs = nil
rcv.CTs = make([]*wire.MsgMixCiphertexts, 0, len(prs))
rcvCtx, rcvCtxCancel := context.WithDeadline(ctx, d.recvCT)
Expand Down Expand Up @@ -1265,7 +1247,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}

// Derive shared secret keys
shared, err := p.kx.SharedSecrets(revealed, sesRun.sid[:], run, sesRun.mcounts)
shared, err := p.kx.SharedSecrets(revealed, sesRun.sid[:], 0, sesRun.mcounts)
if err != nil {
p.triggeredBlame = true
return errTriggeredBlame
Expand All @@ -1283,7 +1265,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)

// Broadcast message commitment and exponential DC-mix vectors for slot
// reservations.
sr := wire.NewMsgMixSlotReserve(*p.id, sesRun.sid, run, srMixBytes, seenCTs)
sr := wire.NewMsgMixSlotReserve(*p.id, sesRun.sid, 0, srMixBytes, seenCTs)
p.sr = sr
c.testHook(hookBeforePeerSRPublish, sesRun, p)
return p.signAndSubmit(sr)
Expand Down Expand Up @@ -1365,7 +1347,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}

// Broadcast XOR DC-net vectors.
dc := wire.NewMsgMixDCNet(*p.id, sesRun.sid, run, p.dcNet, seenSRs)
dc := wire.NewMsgMixDCNet(*p.id, sesRun.sid, 0, p.dcNet, seenSRs)
p.dc = dc
c.testHook(hookBeforePeerDCPublish, sesRun, p)
return p.signAndSubmit(dc)
Expand Down Expand Up @@ -1445,7 +1427,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool)
}

// Broadcast partially signed mix tx
cm := wire.NewMsgMixConfirm(*p.id, sesRun.sid, run,
cm := wire.NewMsgMixConfirm(*p.id, sesRun.sid, 0,
p.coinjoin.Tx().Copy(), seenDCs)
p.cm = cm
return p.signAndSubmit(cm)
Expand Down Expand Up @@ -1555,7 +1537,7 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash,
}
err = c.forLocalPeers(ctx, sesRun, func(p *peer) error {
fp := wire.NewMsgMixFactoredPoly(*p.id, sesRun.sid,
sesRun.run, rootBytes, seenSRs)
0, rootBytes, seenSRs)
return p.signAndSubmit(fp)
})
return roots, err
Expand All @@ -1568,7 +1550,6 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash,
expectedMessages := 1
rcv := &mixpool.Received{
Sid: sesRun.sid,
Run: sesRun.run,
FPs: make([]*wire.MsgMixFactoredPoly, 0, sesRun.mtot),
}
roots := make([]*big.Int, 0, len(a)-1)
Expand Down Expand Up @@ -1924,21 +1905,17 @@ func excludeBlamed(prevRun *sessionRun, blamed blamedIdentities, revealedSecrets
d := prevRun.deadlines
d.restart()

unixEpoch := prevRun.epoch.Unix()
sid := mixing.SortPRsForSession(prs, uint64(unixEpoch))

// mtot, peers, mcounts are all recalculated from the prs before
// calling run()
nextRun := &sessionRun{
sid: prevRun.sid,
run: prevRun.run + 1,
sid: sid,
freshGen: revealedSecrets,
prs: prs,
deadlines: d,
}
if revealedSecrets {
nextRun.freshGen = true
nextRun.prngRun = nextRun.run
} else {
nextRun.freshGen = false
nextRun.prngRun = prevRun.prngRun
}
return nextRun
}

Expand Down
27 changes: 20 additions & 7 deletions mixing/mixclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func TestHonest(t *testing.T) {
select {
case <-ctx.Done():
return
case <-time.After(200 * time.Millisecond):
case <-time.After(500 * time.Millisecond):
}
}
}()
Expand Down Expand Up @@ -367,7 +367,7 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) {
select {
case <-ctx.Done():
return
case <-time.After(200 * time.Millisecond):
case <-time.After(1000 * time.Millisecond):
}
}
}()
Expand Down Expand Up @@ -395,7 +395,7 @@ func testDisruption(t *testing.T, misbehavingID *identity, h hook, f hookFunc) {
func TestCTDisruption(t *testing.T) {
var misbehavingID identity
testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) {
if s.run != 0 || p.myVk != 0 {
if p.myVk != 0 {
return
}
if misbehavingID != [33]byte{} {
Expand All @@ -410,16 +410,23 @@ func TestCTDisruption(t *testing.T) {
func TestCTLength(t *testing.T) {
var misbehavingID identity
testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) {
if s.run != 0 || p.myVk != 0 {
if p.myVk != 0 {
return
}
if misbehavingID != [33]byte{} {
return
}
t.Logf("malicious peer %x: sending too few ciphertexts", p.id[:])
misbehavingID = *p.id
p.ct.Ciphertexts = p.ct.Ciphertexts[:len(p.ct.Ciphertexts)-1]
})

misbehavingID = identity{}
testDisruption(t, &misbehavingID, hookBeforePeerCTPublish, func(c *Client, s *sessionRun, p *peer) {
if s.run != 0 || p.myVk != 0 {
if p.myVk != 0 {
return
}
if misbehavingID != [33]byte{} {
return
}
t.Logf("malicious peer %x: sending too many ciphertexts", p.id[:])
Expand All @@ -431,7 +438,10 @@ func TestCTLength(t *testing.T) {
func TestSRDisruption(t *testing.T) {
var misbehavingID identity
testDisruption(t, &misbehavingID, hookBeforePeerSRPublish, func(c *Client, s *sessionRun, p *peer) {
if s.run != 0 || p.myVk != 0 {
if p.myVk != 0 {
return
}
if misbehavingID != [33]byte{} {
return
}
t.Logf("malicious peer %x: flipping SR bit", p.id[:])
Expand All @@ -443,7 +453,10 @@ func TestSRDisruption(t *testing.T) {
func TestDCDisruption(t *testing.T) {
var misbehavingID identity
testDisruption(t, &misbehavingID, hookBeforePeerDCPublish, func(c *Client, s *sessionRun, p *peer) {
if s.run != 0 || p.myVk != 0 {
if p.myVk != 0 {
return
}
if misbehavingID != [33]byte{} {
return
}
t.Logf("malicious peer %x: flipping DC bit", p.id[:])
Expand Down
Loading

0 comments on commit db25943

Please sign in to comment.