Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
jrick committed Sep 15, 2023
1 parent 9e3e6e9 commit 1ac206c
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 42 deletions.
3 changes: 2 additions & 1 deletion internal/netsync/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,8 @@ func (m *SyncManager) handleMixMsg(mmsg *mixMsg) {

accepted, err := m.cfg.MixPool.AcceptMessage(mmsg.msg)
if err != nil {
log.Errorf("Failed to process mixing message: %v", mmsg.msg)
log.Errorf("Failed to process %T mixing message %v: %v",
mmsg.msg, mmsg.msg.Hash(), err)
return
}
if accepted == nil {
Expand Down
2 changes: 1 addition & 1 deletion mixing/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"math/big"
)

// F is the field prime 2**127 - 1.
// FieldPrime is the field prime 2**127 - 1.
var F *big.Int

func init() {
Expand Down
178 changes: 139 additions & 39 deletions mixing/mixpool/mixpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ package mixpool
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"sort"
"strings"
"sync"
"time"

Expand All @@ -21,8 +23,12 @@ import (
"github.com/decred/dcrd/mixing/utxoproof"
"github.com/decred/dcrd/txscript/v4/stdscript"
"github.com/decred/dcrd/wire"

"github.com/davecgh/go-spew/spew"
)

var _ = spew.Dump // XXX

const minconf = 2
const feeRate = 0.0001e8

Expand Down Expand Up @@ -66,14 +72,12 @@ type runstate struct {
}

type broadcast struct {
funcs []func()
ch chan struct{}
mu sync.Mutex
ch chan struct{}
mu sync.Mutex
}

// wait returns the wait channel that is closed when a broadcast is made due to
// receiving the expected number of messages for a session.
// Waiters must acquire the pool lock before reading messages.
// wait returns the wait channel that is closed whenever a message is received
// for a session. Waiters must acquire the pool lock before reading messages.
func (b *broadcast) wait() <-chan struct{} {
b.mu.Lock()
ch := b.ch
Expand All @@ -94,6 +98,7 @@ func (b *broadcast) signal() {
type Pool struct {
mtx sync.RWMutex
prs map[chainhash.Hash]*wire.MsgMixPR
outPoints map[wire.OutPoint]chainhash.Hash
pool map[chainhash.Hash]entry
messagesByIdentity map[idPubKey][]chainhash.Hash
sessions map[[32]byte]*session
Expand Down Expand Up @@ -138,6 +143,7 @@ type BlockChain interface {
func NewPool(blockchain BlockChain) *Pool {
pool := &Pool{
prs: make(map[chainhash.Hash]*wire.MsgMixPR),
outPoints: make(map[wire.OutPoint]chainhash.Hash),
pool: make(map[chainhash.Hash]entry),
messagesByIdentity: make(map[idPubKey][]chainhash.Hash),
sessions: make(map[[32]byte]*session),
Expand Down Expand Up @@ -254,7 +260,30 @@ func (p *Pool) CompatiblePRs(pairing []byte) []*wire.MsgMixPR {
res = append(res, pr)
}
}
return res

// Sort by decreasing expiries and remove any PRs double spending an
// output with an earlier expiry.
sort.Slice(res, func(i, j int) bool {
return res[i].Expiry >= res[j].Expiry
})
seen := make(map[wire.OutPoint]int64)
for i, pr := range res {
for _, utxo := range pr.UTXOs {
prevExpiry, ok := seen[utxo.OutPoint]
if !ok {
seen[utxo.OutPoint] = pr.Expiry
} else if pr.Expiry < prevExpiry {
res[i] = nil
}
}
}
filtered := res[:0]
for i := range res {
if res[i] != nil {
filtered = append(filtered, res[i])
}
}
return filtered
}

// ExpireMessages removes all messages and sessions that indicate an expiry
Expand Down Expand Up @@ -288,9 +317,9 @@ func (p *Pool) ExpireMessages(height int64) {
}
}

// RemoveSession removes all messages from a completed session. PR messages
// should be removed if the session was successful.
func (p *Pool) RemoveSession(sid [32]byte, removePRs bool) {
// RemoveSession removes all non-PR messages from a completed or errored
// session. PR messages of a successful run (or rerun) must also be removed.
func (p *Pool) RemoveSession(sid [32]byte, removePRs []*wire.MsgMixPR) {
p.mtx.Lock()
defer p.mtx.Unlock()

Expand All @@ -303,11 +332,14 @@ func (p *Pool) RemoveSession(sid [32]byte, removePRs bool) {
for _, r := range ses.runs {
for hash := range r.hashes {
delete(p.pool, hash)
if removePRs {
delete(p.prs, hash)
}
}
}

for _, pr := range removePRs {
hash := pr.Hash()
delete(p.pool, hash)
delete(p.prs, hash)
}
}

// RemoveRun removes all messages from a failed run in a mix session.
Expand Down Expand Up @@ -337,6 +369,7 @@ func (p *Pool) ReceiveKEs(pairing []byte) []*wire.MsgMixKE {
defer p.mtx.RUnlock()

var kes []*wire.MsgMixKE
Entries:
for _, e := range p.pool {
ke, ok := e.msg.(*wire.MsgMixKE)
if !ok {
Expand All @@ -345,16 +378,17 @@ func (p *Pool) ReceiveKEs(pairing []byte) []*wire.MsgMixKE {
for _, prHash := range ke.SeenPRs {
pr := p.prs[prHash]
if pr == nil {
continue
continue Entries
}
prPairing, err := pr.Pairing()
if err != nil {
continue
continue Entries
}
if bytes.Equal(pairing, prPairing) {
kes = append(kes, ke)
if !bytes.Equal(pairing, prPairing) {
continue Entries
}
}
kes = append(kes, ke)
}

return kes
Expand All @@ -377,7 +411,7 @@ type Received struct {
// Receive returns messages matching a session, run, and message type, waiting
// until all described messages have been received, or earlier with the
// messages received so far if the context is cancelled before this point.
func (p *Pool) Receive(ctx context.Context, r *Received) error {
func (p *Pool) Receive(ctx context.Context, expectedMessages int, r *Received) error {
sid := r.Sid
run := r.Run
var bc *broadcast
Expand All @@ -391,24 +425,54 @@ func (p *Pool) Receive(ctx context.Context, r *Received) error {
return fmt.Errorf("unknown session")
}
bc = &ses.bc
p.mtx.RUnlock()

select {
case <-ctx.Done():
// Set error to be returned, but still collect received
// messages
err = ctx.Err()
case <-bc.wait():
if run >= uint32(len(ses.runs)) {
p.mtx.RUnlock()
return fmt.Errorf("unknown run")
}
rs = &ses.runs[run]

p.mtx.RLock()
defer p.mtx.RUnlock()
for {
// Pool is locked. Count if the total number of expected
// messages have been received.
received := 0
for hash := range rs.hashes {
msgtype := p.pool[hash].msgtype
switch {
case msgtype == msgtypeKE && r.KEs != nil:
received++
case msgtype == msgtypeCT && r.CTs != nil:
received++
case msgtype == msgtypeSR && r.SRs != nil:
received++
case msgtype == msgtypeDC && r.DCs != nil:
received++
case msgtype == msgtypeCM && r.CMs != nil:
received++
case msgtype == msgtypeRS && r.RSs != nil:
received++
}
}
if received >= expectedMessages {
break
}

if run >= uint32(len(ses.runs)) {
return fmt.Errorf("unknown run")
// Unlock while waiting for the broadcast channel.
p.mtx.RUnlock()

select {
case <-ctx.Done():
// Set error to be returned, but still lock the pool
// and collect received messages.
err = ctx.Err()
p.mtx.RLock()
break
case <-bc.wait():
}

p.mtx.RLock()
}

rs = &ses.runs[run]
// Pool is locked. Collect all of the messages.
for hash := range rs.hashes {
msg := p.pool[hash].msg
switch msg := msg.(type) {
Expand Down Expand Up @@ -439,6 +503,7 @@ func (p *Pool) Receive(ctx context.Context, r *Received) error {
}
}

p.mtx.RUnlock()
return err
}

Expand All @@ -462,6 +527,11 @@ func (p *Pool) AcceptMessage(msg mixing.Message) (accepted mixing.Message, err e

// Require message to be signed by the presented identity.
if !mixing.VerifyMessageSignature(msg) {
p.mtx.Lock()
s := new(strings.Builder)
msg.WriteSigned(hex.NewEncoder(s))
log.Info(s.String())
p.mtx.Unlock()
return nil, fmt.Errorf("invalid message signature")
}
id := (*idPubKey)(msg.GetIdentity())
Expand Down Expand Up @@ -499,6 +569,11 @@ func (p *Pool) AcceptMessage(msg mixing.Message) (accepted mixing.Message, err e
case *wire.MsgMixDC:
msgtype = msgtypeDC
case *wire.MsgMixCM:
p.mtx.Lock()
s := new(strings.Builder)
msg.WriteSigned(hex.NewEncoder(s))
log.Debug(s.String())
p.mtx.Unlock()
msgtype = msgtypeCM
case *wire.MsgMixRS:
msgtype = msgtypeRS
Expand Down Expand Up @@ -615,8 +690,26 @@ func (p *Pool) acceptPR(pr *wire.MsgMixPR, hash *chainhash.Hash, id *idPubKey) e
return fmt.Errorf("identity reused for a PR message")
}

// Only accept PRs that double spend outpoints if they expire later
// than existing PRs. Otherwise, reject this PR message.
for i := range pr.UTXOs {
otherPRHash := p.outPoints[pr.UTXOs[i].OutPoint]
otherPR, ok := p.prs[otherPRHash]
if !ok {
continue
}
if otherPR.Expiry >= pr.Expiry {
return fmt.Errorf("PR double spends outpoints of " +
"already-accepted PR message without " +
"increasing expiry")
}
}

// Accept the PR
p.prs[*hash] = pr
for i := range pr.UTXOs {
p.outPoints[pr.UTXOs[i].OutPoint] = *hash
}
p.messagesByIdentity[*id] = append(make([]chainhash.Hash, 0, 16), *hash)

return nil
Expand All @@ -634,12 +727,14 @@ func (p *Pool) checkUTXOs(pr *wire.MsgMixPR) error {
if err != nil {
return err
}
if entry.IsSpent() {
return fmt.Errorf("output is not unspent")
if entry == nil || entry.IsSpent() {
return fmt.Errorf("output %v is not unspent",
&utxo.OutPoint)
}
height := entry.BlockHeight()
if !confirmed(minconf, height, curHeight) {
return fmt.Errorf("output is unconfirmed")
return fmt.Errorf("output %v is unconfirmed",
&utxo.OutPoint)
}

// Check proof of key ownership and ability to sign coinjoin
Expand Down Expand Up @@ -757,9 +852,7 @@ func (p *Pool) acceptKE(ke *wire.MsgMixKE, hash *chainhash.Hash, id *idPubKey) e
}

expiry := int64(1<<63 - 1)
hashes := make(map[chainhash.Hash]struct{})
for i := range prevMsgs {
hashes[prevMsgs[i]] = struct{}{}
prExpiry := prs[i].Expires()
if expiry > prExpiry {
expiry = prExpiry
Expand Down Expand Up @@ -815,9 +908,16 @@ func (p *Pool) acceptEntry(msg mixing.Message, msgtype int, hash *chainhash.Hash

count := &rs.counts[msgtype-1] // msgtypes start at 1
*count++
if *count == rs.npeers {
ses.bc.signal()
}
// XXX: testing always signaling for any message accepted to this session.
// waiters must check for signal in a loop
ses.bc.signal()
// if *count == rs.npeers {
// log.Debugf("broadcasting signal for msgtype=%d", msgtype)
// ses.bc.signal()
// } else {
// log.Debugf("not broadcasting msgtype=%d signal: count=%d rs.npeers=%d",
// msgtype, *count, rs.npeers)
// }

return nil
}
Expand Down
23 changes: 22 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3520,6 +3520,26 @@ func (c *reloadableTLSConfig) configFileClient(_ *tls.ClientHelloInfo) (*tls.Con
return c.cachedConfig, nil
}

// mixpoolChain adapts the internal blockchain type with a FetchUtxoEntry
// method that is compatible with the mixpool package.
type mixpoolChain struct {
*blockchain.BlockChain
}

var _ mixpool.BlockChain = (*mixpoolChain)(nil)
var _ mixpool.UtxoEntry = (*blockchain.UtxoEntry)(nil)

func (m *mixpoolChain) FetchUtxoEntry(op wire.OutPoint) (mixpool.UtxoEntry, error) {
entry, err := m.BlockChain.FetchUtxoEntry(op)
if err != nil {
return nil, err
}
if entry == nil {
return nil, err
}
return entry, nil
}

// makeReloadableTLSConfig returns a TLS configuration that will dynamically
// reload the server certificate, server key, and client CAs from the configured
// paths when the files are updated.
Expand Down Expand Up @@ -3866,7 +3886,8 @@ func newServer(ctx context.Context, listenAddrs []string, db database.DB,
}
s.txMemPool = mempool.New(&txC)

s.mixMsgPool = mixpool.NewPool(s.chain)
mixchain := &mixpoolChain{s.chain}
s.mixMsgPool = mixpool.NewPool(mixchain)

s.syncManager = netsync.New(&netsync.Config{
PeerNotifier: &s,
Expand Down

0 comments on commit 1ac206c

Please sign in to comment.