From 92861ed5e666a302d7f5575eff40e87f3679c73d Mon Sep 17 00:00:00 2001 From: Josh Rickmar Date: Tue, 16 Jan 2024 18:48:00 +0000 Subject: [PATCH] mixing: add mixclient package The mixclient package implements a client for the peer-to-peer mixing process. It depends on a mixpool to send and receive mixing messages to and from the network. --- mixing/mixclient/blame.go | 18 + mixing/mixclient/client.go | 900 +++++++++++++++++++++++++++++++++++ mixing/mixclient/coinjoin.go | 288 +++++++++++ mixing/mixclient/errors.go | 7 + mixing/mixclient/limits.go | 71 +++ 5 files changed, 1284 insertions(+) create mode 100644 mixing/mixclient/blame.go create mode 100644 mixing/mixclient/client.go create mode 100644 mixing/mixclient/coinjoin.go create mode 100644 mixing/mixclient/errors.go create mode 100644 mixing/mixclient/limits.go diff --git a/mixing/mixclient/blame.go b/mixing/mixclient/blame.go new file mode 100644 index 0000000000..afca40c5e1 --- /dev/null +++ b/mixing/mixclient/blame.go @@ -0,0 +1,18 @@ +package mixclient + +// blamedIdentities identifies detected misbehaving peers. +// +// If a run returns a blamedIdentities error, these peers are immediately +// excluded and the next run is started. +// +// If a run errors but blame requires revealing secrets and blame assignment, +// a blamedIdentities error will be returned by the blame function. +type blamedIdentities []identity + +func (e blamedIdentities) Error() string { + return "blamed assigned" +} + +func (e blamedIdentities) blamed() []identity { + return ([]identity)(e) +} diff --git a/mixing/mixclient/client.go b/mixing/mixclient/client.go new file mode 100644 index 0000000000..724c075dae --- /dev/null +++ b/mixing/mixclient/client.go @@ -0,0 +1,900 @@ +// Copyright (c) 2023 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package mixclient + +import ( + "bytes" + "context" + cryptorand "crypto/rand" + "crypto/subtle" + "errors" + "fmt" + "io" + "math/big" + "math/bits" + "sort" + "time" + + "decred.org/cspp/v2/solverrpc" + "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/decred/dcrd/mixing" + "github.com/decred/dcrd/mixing/internal/chacha20prng" + "github.com/decred/dcrd/mixing/mixpool" + "github.com/decred/dcrd/wire" +) + +// MinPeers is the minimum number of peers required for a mix run to proceed. +const MinPeers = 3 + +// Wallet signs mix transactions and listens for and broadcasts mixing +// protocol messages. +// +// While wallets are responsible for generating mixed addresses, this duty is +// performed by the generator function provided to NewCoinJoin rather than +// this interface. This allows each CoinJoin to pass in a generator closure +// for different BIP0032 accounts and branches. +type Wallet interface { + // Mixpool returns access to the wallet's mixing message pool. + // + // The mixpool should only be used for message access and deletion, + // but never publishing; SubmitMixMessage must be used instead for + // message publishing. + Mixpool() *mixpool.Pool + + // SubmitMixMessage submits a mixing message to the wallet's mixpool + // and broadcasts it to the network. + SubmitMixMessage(ctx context.Context, msg mixing.Message) error + + // SignInput adds a signature script to a transaction input. + SignInput(tx *wire.MsgTx, index int, prevScript []byte) error + + // PublishTransaction adds the transaction to the wallet and publishes + // it to the network. + PublishTransaction(ctx context.Context, tx *wire.MsgTx) error +} + +type deadlines struct { + recvKE time.Time + recvCT time.Time + recvSR time.Time + recvDC time.Time + recvCM time.Time +} + +func (d *deadlines) reset(begin time.Time) { + t := begin + add := func() time.Time { + t = t.Add(30 * time.Second) + return t + } + d.recvKE = add() + d.recvCT = add() + d.recvSR = add() + d.recvDC = add() + d.recvCM = add() +} + +func (d *deadlines) rerun() { + d.reset(d.recvCM) +} + +// Session represents the client for a peer-to-peer mixing session consisting +// of one or more runs to remove unresponsive or malicious peers. +type Session struct { + wallet Wallet + + deadlines deadlines + epoch time.Duration + + pub *secp256k1.PublicKey + priv *secp256k1.PrivateKey + id *[33]byte + + sid [32]byte + commitment [32]byte + srMsg []*big.Int // random numbers for the exponential slot reservation mix + dcMsg wire.MixVect // anonymized messages to publish in XOR mix + + rand io.Reader // non-PRNG cryptographic rand + + prngSeed *[32]byte + prng *chacha20prng.Reader + + kx *mixing.KX + + coinjoin *CoinJoin + expires uint32 + mcount uint32 + + freshGen bool // Whether next run must generate fresh KX keys, SR/DC messages + + logger Logger +} + +type Logger interface { + Log(args ...interface{}) + Logf(format string, args ...interface{}) +} + +func (s *Session) SetLogger(l Logger) { + s.logger = l +} + +func (s *Session) log(args ...interface{}) { + if s.logger == nil { + return + } + + s.logger.Log(args...) +} + +func (s *Session) logf(format string, args ...interface{}) { + if s.logger == nil { + return + } + + s.logger.Logf(format, args...) +} + +func generateSecp256k1(rand io.Reader) (*secp256k1.PublicKey, *secp256k1.PrivateKey, error) { + if rand == nil { + rand = cryptorand.Reader + } + + privateKey, err := secp256k1.GeneratePrivateKeyFromRand(rand) + if err != nil { + return nil, nil, err + } + + publicKey := privateKey.PubKey() + + return publicKey, privateKey, nil +} + +// NewSession creates a new mixing session client for a coinjoin mix transaction. +func NewSession(w Wallet, rand io.Reader, coinjoin *CoinJoin) (*Session, error) { + pub, priv, err := generateSecp256k1(rand) + if err != nil { + return nil, err + } + ses := &Session{ + wallet: w, + epoch: 10 * time.Minute, + pub: pub, + priv: priv, + id: (*[33]byte)(pub.SerializeCompressed()), + rand: rand, + coinjoin: coinjoin, + expires: coinjoin.prExpiry, + mcount: coinjoin.mcount, + freshGen: true, + } + + return ses, nil +} + +// SetEpoch modifies the session to use a longer or shorter epoch duration. +// The default epoch is 10 minutes. +func (s *Session) SetEpoch(epoch time.Duration) { + s.epoch = epoch +} + +// waitForEpoch blocks until the next epoch, or errors when the context is +// cancelled early. Returns the calculated epoch for stage timeout +// calculations. +func (s *Session) waitForEpoch(ctx context.Context) (time.Time, error) { + now := time.Now().UTC() + epoch := now.Truncate(s.epoch).Add(s.epoch) + duration := epoch.Sub(now) + timer := time.NewTimer(duration) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return epoch, ctx.Err() + case <-timer.C: + return epoch, nil + } +} + +// sleepUntil blocks until the next synchronization point, or errors when the +// context is cancelled early. +func (s *Session) sleepUntil(ctx context.Context, until time.Time) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(time.Until(until)): + return nil + } +} + +// Dicemix performs the dicemix mixing session. +func (s *Session) Dicemix(ctx context.Context) error { + mp := s.wallet.Mixpool() + + pr, err := wire.NewMsgMixPairReq(*s.id, s.expires, s.coinjoin.mixValue, + string(mixing.ScriptClassP2PKHv0), s.coinjoin.tx.Version, + s.coinjoin.tx.LockTime, uint32(s.coinjoin.mcount), + s.coinjoin.inputValue, s.coinjoin.prUTXOs, s.coinjoin.change) + if err != nil { + return err + } + pairingID, err := pr.Pairing() + if err != nil { + return err + } + err = mixing.SignMessage(pr, s.priv) + if err != nil { + return err + } + err = s.wallet.SubmitMixMessage(ctx, pr) + if err != nil { + return fmt.Errorf("submit PR: %w", err) + } + + var epoch time.Time + var prs []*wire.MsgMixPairReq + for { + epoch, err = s.waitForEpoch(ctx) + if err != nil { + return err + } + + prs = mp.CompatiblePRs(pairingID) + if len(prs) >= MinPeers { + break + } + } + + newSession := true + var expiry uint32 + s.deadlines.reset(epoch) + for i := uint32(0); ; i++ { + // Calculate new deadlines for reruns and repaired sessions. + if i != 0 { + s.deadlines.rerun() + } + if newSession { + sort.Slice(prs, func(i, j int) bool { + a := prs[i].Hash() + b := prs[j].Hash() + return bytes.Compare(a[:], b[:]) == -1 + }) + + prHashes := make([]chainhash.Hash, len(prs)) + for i := range prs { + prHashes[i] = prs[i].Hash() + } + sid := mixing.DeriveSessionID(prHashes) + s.sid = sid + + // Session expires with the earliest PR expiry + for _, pr := range prs { + if expiry == 0 || pr.Expiry < expiry { + expiry = pr.Expiry + } + } + + i = 0 + newSession = false + } + + err := s.run(ctx, i, expiry, prs) + var recreatedSessionErr *recreatedSessionError + if errors.As(err, &recreatedSessionErr) { + prs = recreatedSessionErr.prs + + if len(prs) < MinPeers { + return ErrTooFewPeers + } + + newSession = true + continue + } + + var excludePeersErr excludePeersError + if errors.As(err, &excludePeersErr) { + prs = excludePeersErr.Rerun() + continue + } + if err != nil { + // blame and rerun + return err + } + return nil + } +} + +var ( + errRunStageTimeout = errors.New("mix run stage timeout") + errUnblamedRerun = errors.New("rerun") +) + +type recreatedSessionError struct { + prs []*wire.MsgMixPairReq +} + +func (e *recreatedSessionError) Error() string { return "recreated session" } + +type excludePeersError interface { + error + Exclude() []identity + Rerun() []*wire.MsgMixPairReq +} + +type timeoutError struct { + exclude []identity + prs []*wire.MsgMixPairReq +} + +func (e *timeoutError) Error() string { return "timeout" } +func (e *timeoutError) Exclude() []identity { return e.exclude } +func (e *timeoutError) Rerun() []*wire.MsgMixPairReq { return e.prs } + +func (s *Session) run(ctx context.Context, run uint32, expiry uint32, prs []*wire.MsgMixPairReq) error { + var blamed blamedIdentities + + mp := s.wallet.Mixpool() + + if len(prs) < MinPeers { + return ErrTooFewPeers + } + pairingID, err := prs[0].Pairing() + if err != nil { + return err + } + + sort.Slice(prs, func(i, j int) bool { + a := prs[i].Hash() + b := prs[j].Hash() + return bytes.Compare(a[:], b[:]) == -1 + }) + + // A map of identity public keys to their PR position sort all + // messages in the same order as the PRs are ordered. + identityIndices := make(map[identity]int) + for i, pr := range prs { + identityIndices[pr.Identity] = i + } + + s.coinjoin.reset(prs) + + vk := make([]*secp256k1.PublicKey, len(prs)) + for i := range prs { + pub, err := secp256k1.ParsePubKey(prs[i].Identity[:]) + if err != nil { + return err + } + vk[i] = pub + } + + var mtot uint32 + var myVk uint32 + var myStart uint32 + foundSelf := false + mcounts := make([]uint32, len(vk)) + for i := range vk { + pubSerialized := vk[i].SerializeCompressed() + if bytes.Equal(pubSerialized, s.id[:]) { + foundSelf = true + myVk = uint32(i) + myStart = mtot + } + if prs[i].MessageCount == 0 { + return errors.New("non-positive message count") + } + mtot += prs[i].MessageCount + mcounts[i] = prs[i].MessageCount + } + if !foundSelf { + return errors.New("failed to find own PR") + } + + type sessionRun struct { + session *Session + run uint32 + + // Peer information + vk []*secp256k1.PublicKey // session pubkeys + mcounts []uint32 + mtot uint32 + myVk uint32 + myStart uint32 + + // Exponential slot reservation mix + srKP [][][]byte // shared keys for exp dc-net + srMix [][]*big.Int + srMixBytes [][][]byte + + // XOR DC-net + dcKP [][]wire.MixVect + dcNet []wire.MixVect + } + r := &sessionRun{ + session: s, + run: run, + mcounts: mcounts, + mtot: mtot, + vk: vk, + myVk: myVk, + myStart: myStart, + } + + if s.freshGen { + s.freshGen = false + + // Generate a new PRNG seed + s.prngSeed = new([32]byte) + _, err = io.ReadFull(s.rand, s.prngSeed[:]) + if err != nil { + return err + } + s.prng = chacha20prng.New(s.prngSeed[:], run) + + // Generate fresh keys from this run's PRNG + var err error + s.kx, err = mixing.NewKX(s.prng) + if err != nil { + return err + } + + // Generate fresh SR messages + s.srMsg = make([]*big.Int, s.mcount) + for i := range s.srMsg { + s.srMsg[i], err = cryptorand.Int(s.rand, mixing.F) + if err != nil { + return err + } + } + + // Generate fresh DC messages + s.dcMsg, err = s.coinjoin.gen() + if err != nil { + return err + } + if len(s.dcMsg) != int(s.mcount) { + return errors.New("Gen returned wrong message count") + } + for _, m := range s.dcMsg { + if len(m) != msize { + err := fmt.Errorf("Gen returned bad message "+ + "length [%v != %v]", len(m), msize) + return err + } + } + } else { + // Generate a new PRNG from existing seed and this run number. + s.prng = chacha20prng.New(s.prngSeed[:], run) + } + + seenPRs := make([]chainhash.Hash, len(prs)) + for i := range prs { + seenPRs[i] = prs[i].Hash() + } + + // Perform key exchange + srMsgBytes := make([][]byte, len(s.srMsg)) + for i := range s.srMsg { + srMsgBytes[i] = s.srMsg[i].Bytes() + } + rs := wire.NewMsgMixSecrets(*s.id, s.sid, expiry, run, *s.prngSeed, srMsgBytes, s.dcMsg) + commitment := rs.Hash() + ecdhPub := *(*[33]byte)(s.kx.ECDHPublicKey.SerializeCompressed()) + pqPub := *s.kx.PQPublicKey + ke := wire.NewMsgMixKeyExchange(*s.id, s.sid, expiry, run, ecdhPub, pqPub, commitment, seenPRs) + err = mixing.SignMessage(ke, s.priv) + if err != nil { + return err + } + err = s.wallet.SubmitMixMessage(ctx, ke) + if err != nil { + return fmt.Errorf("submit KE: %w", err) + } + + // In run 0, the run can proceed when there are KE messages from each + // peer selected for this session and each KE refers to known PR + // messages. This is done using mp.Receive() as it only finds + // messages matching the session. + // + // If this is not the case, we must find a different session that + // other peers are able to participate in. This must be a subset of + // the original PRs that peers have seen, and optimizes for including + // the most peers. This is done using mp.ReceiveKEs(), which returns + // all KEs matching a PR pairing, even if they began in different + // sessions. + rcv := new(mixpool.Received) + rcv.Run = run + rcv.Sid = s.sid + rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(vk)) + rcvCtx, rcvCtxCancel := context.WithDeadlineCause(ctx, + s.deadlines.recvKE, errRunStageTimeout) + err = mp.Receive(rcvCtx, len(vk), rcv) + rcvCtxCancel() + kes := rcv.KEs + if len(kes) != len(vk) || errors.Is(err, errRunStageTimeout) { + // Remove any unresponsive peers that sent PRs, but did not + // attempt to start or continue the session by sending a KE. + unresponsive := make(map[identity]struct{}) + for _, pr := range prs { + unresponsive[pr.Identity] = struct{}{} + } + for _, ke := range prs { + delete(unresponsive, ke.Identity) + } + if len(unresponsive) != 0 { + identities := make([]identity, 0, len(unresponsive)) + for id := range unresponsive { + identities = append(identities, id) + } + mp.RemoveIdentities(identities) + } + + // Rerun keeping the same session if peers dropped out + if run != 0 { + s.logf("received %d KEs for %d peers; rerunning", len(kes), len(vk)) + return errUnblamedRerun + } + + // Based on the seen pairing data available, find an + // alternative session. + if err := s.sleepUntil(ctx, s.deadlines.recvKE); err != nil { + return err + } + prs := mp.CompatiblePRs(pairingID) // shadowed + kes := mp.ReceiveKEs(pairingID) + + counts := make(map[chainhash.Hash]int) + maxCount := 0 + for _, ke := range kes { + for _, prHash := range ke.SeenPRs { + counts[prHash]++ + count := counts[prHash] + if maxCount < count { + maxCount = count + } + } + } + kept := 0 + keptIdentities := make(map[identity]struct{}) + for _, pr := range prs { + prHash := pr.Hash() + if counts[prHash] == maxCount { + prs[kept] = pr + keptIdentities[pr.Identity] = struct{}{} + } + } + prs = prs[:kept] + + if len(prs) != maxCount { + // We have not seen the PR messages required to + // participate in the new session. + return errors.New("aborted session") + } + + // Signal caller to begin a new session with these PRs. + return &recreatedSessionError{prs: prs} + } + if err != nil { + return err + } + sort.Slice(kes, func(i, j int) bool { + a := identityIndices[kes[i].Identity] + b := identityIndices[kes[j].Identity] + return a < b + }) + + revealed := mixing.RevealedKeys{ + ECDHPublicKeys: make([]*secp256k1.PublicKey, 0, len(vk)), + Ciphertexts: make([]mixing.PQCiphertext, 0, len(vk)), + MyIndex: r.myVk, + } + pqpk := make([]*mixing.PQPublicKey, 0, len(vk)) + for _, ke := range kes { + ecdhPub, err := secp256k1.ParsePubKey(ke.ECDH[:]) + if err != nil { + blamed = append(blamed, ke.Identity) + continue + } + revealed.ECDHPublicKeys = append(revealed.ECDHPublicKeys, ecdhPub) + pqpk = append(pqpk, &ke.PQPK) + } + if len(blamed) > 0 { + return blamed + } + + // Create shared keys and ciphextexts for each peer + pqct, err := s.kx.Encapsulate(s.prng, pqpk, int(r.myVk)) + if err != nil { + return err + } + + // Sent ciphertext messages + seenKEs := make([]chainhash.Hash, len(kes)) + for i := range kes { + seenKEs[i] = kes[i].Hash() + } + ct := wire.NewMsgMixCiphertexts(*s.id, s.sid, expiry, run, pqct, seenKEs) + err = mixing.SignMessage(ct, s.priv) + if err != nil { + return err + } + err = s.wallet.SubmitMixMessage(ctx, ct) + if err != nil { + return fmt.Errorf("submit CT: %w", err) + } + + // Receive all ciphertext messages + rcv.KEs = nil + rcv.CTs = make([]*wire.MsgMixCiphertexts, 0, len(vk)) + rcv.RSs = nil // XXX + rcvCtx, rcvCtxCancel = context.WithDeadlineCause(ctx, + s.deadlines.recvCT, errRunStageTimeout) + err = mp.Receive(rcvCtx, len(vk), rcv) + rcvCtxCancel() + cts := rcv.CTs + if len(cts) != len(vk) || errors.Is(err, errRunStageTimeout) { + // Blame peers + s.logf("received %d CTs for %d peers; rerunning", len(cts), len(vk)) + return errUnblamedRerun + } + if err != nil { + return err + } + sort.Slice(cts, func(i, j int) bool { + a := identityIndices[cts[i].Identity] + b := identityIndices[cts[j].Identity] + return a < b + }) + for i := range cts { + revealed.Ciphertexts = append(revealed.Ciphertexts, cts[i].Ciphertexts[r.myVk]) + } + + // Derive shared secret keys + shared, err := r.session.kx.SharedSecrets(&revealed, s.sid[:], r.run, r.mcounts) + if err != nil { + return err + } + r.srKP = shared.SRSecrets + r.dcKP = shared.DCSecrets + + // Calculate slot reservation DC-net vectors + r.srMix = make([][]*big.Int, s.mcount) + for i := 0; i < int(s.mcount); i++ { + pads := mixing.SRMixPads(r.srKP[i], myStart+uint32(i)) + r.srMix[i] = mixing.SRMix(s.srMsg[i], pads) + } + srMixBytes := mixing.IntVectorsToBytes(r.srMix) + + // Broadcast message commitment and exponential DC-mix vectors for slot + // reservations. + seenCTs := make([]chainhash.Hash, len(cts)) + for i := range cts { + seenCTs[i] = cts[i].Hash() + } + sr := wire.NewMsgMixSlotReserve(*s.id, s.sid, expiry, r.run, srMixBytes, seenCTs) + err = mixing.SignMessage(sr, s.priv) + if err != nil { + return err + } + err = s.wallet.SubmitMixMessage(ctx, sr) + if err != nil { + return fmt.Errorf("submit SR: %w", err) + } + + // Receive all slot reservation messages + rcv.CTs = nil + rcv.SRs = make([]*wire.MsgMixSlotReserve, 0, len(vk)) + rcv.RSs = nil // XXX + rcvCtx, rcvCtxCancel = context.WithDeadlineCause(ctx, + s.deadlines.recvSR, errRunStageTimeout) + err = mp.Receive(rcvCtx, len(vk), rcv) + rcvCtxCancel() + srs := rcv.SRs + if len(srs) != len(vk) || errors.Is(err, errRunStageTimeout) { + // Blame peers + s.logf("received %d SRs for %d peers; rerunning", len(srs), len(vk)) + return errUnblamedRerun + } + if err != nil { + return err + } + sort.Slice(srs, func(i, j int) bool { + a := identityIndices[srs[i].Identity] + b := identityIndices[srs[j].Identity] + return a < b + }) + + // Recover roots + var roots []*big.Int + vs := make([][][]byte, 0, len(r.vk)) + for _, sr := range srs { + vs = append(vs, sr.DCMix...) + } + powerSums := mixing.AddVectors(mixing.IntVectorsFromBytes(vs)...) + coeffs := mixing.Coefficients(powerSums) + roots, err = solverrpc.Roots(coeffs, mixing.F) + if err != nil { + // Blame peers + return errors.New("blame required") + } + + // Find reserved slots + slots := make([]uint32, 0, s.mcount) + for _, m := range s.srMsg { + slot := constTimeSlotSearch(m, roots) + if slot == -1 { + // Blame peers + return errors.New("blame required") + } + slots = append(slots, uint32(slot)) + } + + // Calculate XOR DC-net vectors + r.dcNet = make([]wire.MixVect, s.mcount) + for i, slot := range slots { + my := r.myStart + uint32(i) + pads := mixing.DCMixPads(r.dcKP[i], my) + r.dcNet[i] = wire.MixVect(mixing.DCMix(pads, s.dcMsg[i][:], slot)) + } + + // Broadcast XOR DC-net vectors. + seenSRs := make([]chainhash.Hash, len(cts)) + for i := range srs { + seenSRs[i] = srs[i].Hash() + } + dc := wire.NewMsgMixDCNet(*s.id, s.sid, expiry, r.run, r.dcNet, seenSRs) + err = mixing.SignMessage(dc, s.priv) + if err != nil { + return err + } + err = s.wallet.SubmitMixMessage(ctx, dc) + if err != nil { + return fmt.Errorf("submit DC: %w", err) + } + + // Receive all DC messages + rcv.SRs = nil + rcv.DCs = make([]*wire.MsgMixDCNet, 0, len(vk)) + rcv.RSs = nil // XXX + rcvCtx, rcvCtxCancel = context.WithDeadlineCause(ctx, + s.deadlines.recvDC, errRunStageTimeout) + err = mp.Receive(rcvCtx, len(vk), rcv) + rcvCtxCancel() + dcs := rcv.DCs + if len(dcs) != len(vk) || errors.Is(err, errRunStageTimeout) { + // Blame peers + s.logf("received %d DCs for %d peers; rerunning", len(dcs), len(vk)) + return errUnblamedRerun + } + if err != nil { + return err + } + sort.Slice(dcs, func(i, j int) bool { + a := identityIndices[dcs[i].Identity] + b := identityIndices[dcs[j].Identity] + return a < b + }) + + // Solve XOR dc-net + dcVecs := make([]mixing.Vec, 0, mtot) + for _, dc := range dcs { + for _, vec := range dc.DCNet { + dcVecs = append(dcVecs, mixing.Vec(vec)) + } + } + res := mixing.XorVectors(dcVecs) + + // Add outputs for each mixed message + for i := range res { + mixedMsg := res[i][:] + s.coinjoin.addMixedMessage(mixedMsg) + } + s.coinjoin.sort() + + // Confirm that our messages and change are present + err = s.coinjoin.confirm(s.wallet) + if err != nil { + return err + } + if err != nil { + type missingMessage interface { + error + MissingMessage() + } + var errMM missingMessage + if errors.As(err, &errMM) { + // Blame peers + s.logf("missing message; rerunning", len(cts), len(vk)) + return errUnblamedRerun + } + return err + } + + // Broadcast partially signed mix tx + seenDCs := make([]chainhash.Hash, len(dcs)) + for i := range dcs { + seenDCs[i] = dcs[i].Hash() + } + cm := wire.NewMsgMixConfirm(*s.id, s.sid, expiry, r.run, + s.coinjoin.Tx().Copy(), seenDCs) + err = mixing.SignMessage(cm, s.priv) + if err != nil { + return err + } + err = s.wallet.SubmitMixMessage(ctx, cm) + if err != nil { + return fmt.Errorf("submit CM: %w", err) + } + + // Receive all CM messages + rcv.DCs = nil + rcv.CMs = make([]*wire.MsgMixConfirm, 0, len(vk)) + rcv.RSs = nil // XXX + rcvCtx, rcvCtxCancel = context.WithDeadlineCause(ctx, + s.deadlines.recvCM, errRunStageTimeout) + err = mp.Receive(rcvCtx, len(vk), rcv) + rcvCtxCancel() + cms := rcv.CMs + if len(cms) != len(vk) || errors.Is(err, errRunStageTimeout) { + // Blame peers + s.logf("received %d CMs for %d peers; rerunning", len(cms), len(vk)) + return errUnblamedRerun + } + if err != nil { + return err + } + sort.Slice(cms, func(i, j int) bool { + a := identityIndices[cms[i].Identity] + b := identityIndices[cms[j].Identity] + return a < b + }) + + // Merge signatures + for _, cm := range cms { + err := s.coinjoin.mergeSignatures(cm) + if err != nil { + blamed = append(blamed, cm.Identity) + } + } + if len(blamed) > 0 { + return blamed + } + + err = s.wallet.PublishTransaction(ctx, s.coinjoin.tx) + if err != nil { + return err + } + + mp.RemoveSession(s.sid, prs) + + return nil +} + +var fieldLen = uint(len(mixing.F.Bytes())) + +// constTimeSlotSearch searches for the index of secret in roots in constant time. +// Returns -1 if the secret is not found. +func constTimeSlotSearch(secret *big.Int, roots []*big.Int) int { + paddedSecret := make([]byte, fieldLen) + secretBytes := secret.Bytes() + off, _ := bits.Sub(fieldLen, uint(len(secretBytes)), 0) + copy(paddedSecret[off:], secretBytes) + + slot := -1 + buf := make([]byte, fieldLen) + for i := range roots { + rootBytes := roots[i].Bytes() + off, _ := bits.Sub(fieldLen, uint(len(rootBytes)), 0) + copy(buf[off:], rootBytes) + cmp := subtle.ConstantTimeCompare(paddedSecret, buf) + slot = subtle.ConstantTimeSelect(cmp, i, slot) + for j := range buf { + buf[j] = 0 + } + } + return slot +} diff --git a/mixing/mixclient/coinjoin.go b/mixing/mixclient/coinjoin.go new file mode 100644 index 0000000000..de406c39cd --- /dev/null +++ b/mixing/mixclient/coinjoin.go @@ -0,0 +1,288 @@ +// Copyright (c) 2023 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package mixclient + +import ( + "crypto/subtle" + "errors" + + "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/decred/dcrd/dcrutil/v4/txsort" + "github.com/decred/dcrd/mixing/utxoproof" + "github.com/decred/dcrd/txscript/v4" + "github.com/decred/dcrd/wire" +) + +// msize is the message size of a mixed message (hash160). +const msize = 20 + +var ( + // errMissingGen indicates one or more dishonest peers in the DC-net + // that must be removed by revealing secrets, assigning blame, and + // rerunning with them excluded. + errMissingGen = errors.New("coinjoin is missing gen output") + + // errSignedWrongTx indicates a peer signed a different transaction + // than the coinjoin transaction that we constructed. Peers who sent + // these invalid CM messages must be removed in the next rerun. + errSignedWrongTx = errors.New("peer signed incorrect mix transaction") +) + +// identity is a peer's secp256k1 compressed public key. +type identity = [33]byte + +// CoinJoin tracks the in-progress coinjoin transaction for a single peer as +// the mixing protocol is performed. +type CoinJoin struct { + txHash chainhash.Hash + + genFunc GenFunc + + tx *wire.MsgTx + change *wire.TxOut + prevScripts map[wire.OutPoint][]byte + peerPRs map[identity]*wire.MsgMixPairReq + contributed map[wire.OutPoint]identity + prUTXOs []wire.MixPairReqUTXO + myInputs []int + gens wire.MixVect + genScripts [][]byte + mixedIndices []int + + mixValue int64 + inputValue int64 + prExpiry uint32 + mcount uint32 +} + +// GenFunc generates fresh secp256k1 P2PKH hash160s from the wallet. +type GenFunc func(count uint32) (wire.MixVect, error) + +// NewCoinJoin creates the initial coinjoin transaction. Inputs must be +// contributed to the coinjoin by one or more calls to AddInput. +func NewCoinJoin(gen GenFunc, change *wire.TxOut, mixValue int64, prExpiry uint32, mcount uint32) *CoinJoin { + return &CoinJoin{ + genFunc: gen, + tx: wire.NewMsgTx(), + change: change, + peerPRs: make(map[identity]*wire.MsgMixPairReq), + contributed: make(map[wire.OutPoint]identity), + prevScripts: make(map[wire.OutPoint][]byte), + mixValue: mixValue, + prExpiry: prExpiry, + mcount: mcount, + } +} + +// AddInput adds an contributed input to the coinjoin transaction. +// +// The private key is used to generate a UTXO signature proof demonstrating +// that this wallet is able to spend the UTXO. The private key is not +// retained by the CoinJoin structure and may be zerod from memory after +// AddInput returns. +func (c *CoinJoin) AddInput(input *wire.TxIn, prevScript []byte, prevScriptVersion uint16, + privKey *secp256k1.PrivateKey) error { + + pub := privKey.PubKey().SerializeCompressed() + keyPair := utxoproof.Secp256k1KeyPair{ + Priv: privKey, + Pub: pub, + } + proofSig, err := keyPair.SignUtxoProof(c.prExpiry) + if err != nil { + return err + } + + c.prUTXOs = append(c.prUTXOs, wire.MixPairReqUTXO{ + OutPoint: input.PreviousOutPoint, + Script: nil, // Only for P2SH + PubKey: pub, + Signature: proofSig, + }) + + c.prevScripts[input.PreviousOutPoint] = prevScript + c.inputValue += input.ValueIn + + return nil +} + +// reset initializes the coinjoin transaction with all peers' unmixed data +// from their pair request messages. +func (c *CoinJoin) reset(prs []*wire.MsgMixPairReq) { + c.tx.TxIn = c.tx.TxIn[:0] + c.tx.TxOut = c.tx.TxOut[:0] + c.mixedIndices = c.mixedIndices[:0] + c.myInputs = c.myInputs[:0] + for id := range c.peerPRs { + delete(c.peerPRs, id) + } + for outpoint := range c.contributed { + delete(c.contributed, outpoint) + } + + for _, pr := range prs { + c.peerPRs[pr.Identity] = pr + for i := range pr.UTXOs { + prevOutPoint := &pr.UTXOs[i].OutPoint + in := wire.NewTxIn(prevOutPoint, wire.NullValueIn, nil) + c.tx.AddTxIn(in) + c.contributed[*prevOutPoint] = pr.Identity + } + if pr.Change != nil { + c.tx.AddTxOut(pr.Change) + } + } +} + +// gen calls the message generator function, recording and returning the +// freshly generated messages to be mixed. +func (c *CoinJoin) gen() (wire.MixVect, error) { + gens, err := c.genFunc(c.mcount) + if err != nil { + return nil, err + } + + genScripts := make([][]byte, len(gens)) + for i, m := range gens { + script := make([]byte, 25) + script[0] = txscript.OP_DUP + script[1] = txscript.OP_HASH160 + script[2] = txscript.OP_DATA_20 + copy(script[3:23], m[:]) + script[23] = txscript.OP_EQUALVERIFY + script[24] = txscript.OP_CHECKSIG + genScripts[i] = script + } + + c.gens = gens + c.genScripts = genScripts + return gens, nil +} + +// addMixedMessage adds a transaction output paying to the mixed hash160 +// message. +func (c *CoinJoin) addMixedMessage(m []byte) { + if len(m) != msize { + return + } + + script := make([]byte, 25) + script[0] = txscript.OP_DUP + script[1] = txscript.OP_HASH160 + script[2] = txscript.OP_DATA_20 + copy(script[3:23], m[:]) + script[23] = txscript.OP_EQUALVERIFY + script[24] = txscript.OP_CHECKSIG + + c.tx.AddTxOut(wire.NewTxOut(c.mixValue, script)) +} + +// sort performs an in-place sort of the transaction, retaining any and all +// internal bookkeeping about which inputs and outputs are contributed by the +// client for the confirmation checks. sort must be called before any +// signatures are created or included so that all peers deterministicly agree +// on the trasaction to sign. +func (c *CoinJoin) sort() { + txsort.InPlaceSort(c.tx) + + c.myInputs = c.myInputs[:0] + for i, in := range c.tx.TxIn { + _, ok := c.prevScripts[in.PreviousOutPoint] + if ok { + c.myInputs = append(c.myInputs, i) + } + } + + c.txHash = c.tx.TxHash() +} + +// constantTimeOutputSearch searches for the output indices of mixed outputs to +// verify inclusion in a coinjoin. It is constant time such that, for each +// searched script, all outputs with equal value, script versions, and script +// lengths matching the searched output are checked in constant time. +func constantTimeOutputSearch(tx *wire.MsgTx, value int64, scriptVer uint16, scripts [][]byte) ([]int, error) { + var scan []int + for i, out := range tx.TxOut { + if out.Value != value { + continue + } + if out.Version != scriptVer { + continue + } + if len(out.PkScript) != len(scripts[0]) { + continue + } + scan = append(scan, i) + } + indices := make([]int, 0, len(scan)) + var missing int + for _, s := range scripts { + idx := -1 + for _, i := range scan { + eq := subtle.ConstantTimeCompare(tx.TxOut[i].PkScript, s) + idx = subtle.ConstantTimeSelect(eq, i, idx) + } + indices = append(indices, idx) + eq := subtle.ConstantTimeEq(int32(idx), -1) + missing = subtle.ConstantTimeSelect(eq, 1, missing) + } + if missing == 1 { + return nil, errMissingGen + } + return indices, nil +} + +// confirm ensures that generated messages are present in the transaction and +// signs inputs being contributed by the peer. returns errMissingGen to +// trigger blame assignment if a message is not found. +func (c *CoinJoin) confirm(wallet Wallet) error { + genIndices, err := constantTimeOutputSearch(c.tx, c.mixValue, 0 /* XXX */, c.genScripts) + if err != nil { + return err + } + + for _, input := range c.myInputs { + prevOutpoint := &c.tx.TxIn[input].PreviousOutPoint + err := wallet.SignInput(c.tx, input, c.prevScripts[*prevOutpoint]) + if err != nil { + return err + } + } + + c.mixedIndices = genIndices + return nil +} + +// mergeSignatures adds the signatures from another peer's CM message to the +// coinjoin transaction. Only those inputs contributed by the peer are +// modified. +// +// Peers must be removed in the next run if mergeSignatures returns an error. +func (c *CoinJoin) mergeSignatures(cm *wire.MsgMixConfirm) error { + // Signatures may only be used if an identical transaction was signed. + if cm.Mix.TxHash() != c.txHash { + return errSignedWrongTx + } + + for i, in := range cm.Mix.TxIn { + if cm.Identity != c.contributed[in.PreviousOutPoint] { + continue + } + c.tx.TxIn[i].SignatureScript = in.SignatureScript + } + + return nil +} + +// Tx returns the coinjoin transaction. +func (c *CoinJoin) Tx() *wire.MsgTx { + return c.tx +} + +// MixedIndices returns the peer's mixed transaction output indices. +func (c *CoinJoin) MixedIndices() []int { + return c.mixedIndices +} diff --git a/mixing/mixclient/errors.go b/mixing/mixclient/errors.go new file mode 100644 index 0000000000..1e22db1aa5 --- /dev/null +++ b/mixing/mixclient/errors.go @@ -0,0 +1,7 @@ +package mixclient + +import "errors" + +var ( + ErrTooFewPeers = errors.New("not enough peers required to mix") +) diff --git a/mixing/mixclient/limits.go b/mixing/mixclient/limits.go new file mode 100644 index 0000000000..29d2800298 --- /dev/null +++ b/mixing/mixclient/limits.go @@ -0,0 +1,71 @@ +package mixclient + +import ( + "errors" + + "github.com/decred/dcrd/wire" +) + +const ( + redeemP2PKHv0SigScriptSize = 1 + 73 + 1 + 33 + p2pkhv0PkScriptSize = 1 + 1 + 1 + 20 + 1 + 1 +) + +func estimateP2PKHv0SerializeSize(inputs, outputs int, hasChange bool) int { + // Sum the estimated sizes of the inputs and outputs. + txInsSize := inputs * estimateInputSize(redeemP2PKHv0SigScriptSize) + txOutsSize := outputs * estimateOutputSize(p2pkhv0PkScriptSize) + + changeSize := 0 + if hasChange { + changeSize = estimateOutputSize(p2pkhv0PkScriptSize) + outputs++ + } + + // 12 additional bytes are for version, locktime and expiry. + return 12 + (2 * wire.VarIntSerializeSize(uint64(inputs))) + + wire.VarIntSerializeSize(uint64(outputs)) + + txInsSize + txOutsSize + changeSize +} + +// estimateInputSize returns the worst case serialize size estimate for a tx input +func estimateInputSize(scriptSize int) int { + return 32 + // previous tx + 4 + // output index + 1 + // tree + 8 + // amount + 4 + // block height + 4 + // block index + wire.VarIntSerializeSize(uint64(scriptSize)) + // size of script + scriptSize + // script itself + 4 // sequence +} + +// estimateOutputSize returns the worst case serialize size estimate for a tx output +func estimateOutputSize(scriptSize int) int { + return 8 + // previous tx + 2 + // version + wire.VarIntSerializeSize(uint64(scriptSize)) + // size of script + scriptSize // script itself +} + +func estimateIsStandardSize(inputs, outputs int) bool { + const maxSize = 100000 + + estimated := estimateP2PKHv0SerializeSize(inputs, outputs, false) + return estimated <= maxSize +} + +// checkLimited determines if adding an peer with the provided unmixed values +// and a total number of mixed outputs would cause the transaction size to +// exceed the maximum allowed size. Peers must be excluded from mixes if +// their contributions would cause the total transaction size to be too large, +// even if they have not acted maliciously in the mixing protocol. +func checkLimited(currentTx, unmixed *wire.MsgTx, totalMessages int) error { + totalInputs := len(currentTx.TxIn) + len(unmixed.TxIn) + totalOutputs := len(currentTx.TxOut) + len(unmixed.TxOut) + totalMessages + if !estimateIsStandardSize(totalInputs, totalOutputs) { + return errors.New("tx size would exceed standardness rules") + } + return nil +}