From 2f8bdd3c55c2257c7c4ebf5f5002237fed8f6bdf Mon Sep 17 00:00:00 2001 From: Josh Rickmar Date: Tue, 11 Jun 2024 18:44:42 +0000 Subject: [PATCH] mixpool: Remove Receive expectedMessages argument The number of expected messages can be determined by the capacity of the intended message type's slice. --- mixing/mixclient/blame.go | 8 +++---- mixing/mixclient/client.go | 17 +++++++------- mixing/mixpool/mixpool.go | 48 ++++++++++++++++++++++---------------- 3 files changed, 40 insertions(+), 33 deletions(-) diff --git a/mixing/mixclient/blame.go b/mixing/mixclient/blame.go index 248d633ff..887f01d57 100644 --- a/mixing/mixclient/blame.go +++ b/mixing/mixclient/blame.go @@ -80,8 +80,8 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { // Receive currently-revealed secrets rcv := new(mixpool.Received) rcv.Sid = sesRun.sid - rcv.RSs = make([]*wire.MsgMixSecrets, 0, len(sesRun.prs)) - _ = mp.Receive(ctx, 0, rcv) + rcv.RSs = make([]*wire.MsgMixSecrets, 0, 1) + _ = mp.Receive(ctx, rcv) rsHashes := make([]chainhash.Hash, len(rcv.RSs)) for _, rs := range rcv.RSs { rsHashes = append(rsHashes, rs.Hash()) @@ -101,8 +101,8 @@ func (c *Client) blame(ctx context.Context, sesRun *sessionRun) (err error) { } // Wait for all secrets, or timeout. - rcv.RSs = rcv.RSs[:0] - _ = mp.Receive(ctx, len(sesRun.prs), rcv) + rcv.RSs = make([]*wire.MsgMixSecrets, 0, len(sesRun.prs)) + _ = mp.Receive(ctx, rcv) rss := rcv.RSs for _, rs := range rcv.RSs { if idx, ok := identityIndices[rs.Identity]; ok { diff --git a/mixing/mixclient/client.go b/mixing/mixclient/client.go index 9794bbe1f..ef00b2f5e 100644 --- a/mixing/mixclient/client.go +++ b/mixing/mixclient/client.go @@ -1093,7 +1093,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(sesRun.prs)) ctx, cancel := context.WithDeadline(ctx, d.recvKE) defer cancel() - err = mp.Receive(ctx, len(sesRun.prs), rcv) + err = mp.Receive(ctx, rcv) if ctx.Err() != nil { err = fmt.Errorf("session %x KE receive context cancelled: %w", sesRun.sid[:], ctx.Err()) @@ -1264,7 +1264,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) rcv.KEs = nil rcv.CTs = make([]*wire.MsgMixCiphertexts, 0, len(prs)) rcvCtx, rcvCtxCancel := context.WithDeadline(ctx, d.recvCT) - err = mp.Receive(rcvCtx, len(prs), rcv) + err = mp.Receive(rcvCtx, rcv) rcvCtxCancel() cts := rcv.CTs for _, ct := range cts { @@ -1350,7 +1350,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) rcv.CTs = nil rcv.SRs = make([]*wire.MsgMixSlotReserve, 0, len(prs)) rcvCtx, rcvCtxCancel = context.WithDeadline(ctx, d.recvSR) - err = mp.Receive(rcvCtx, len(prs), rcv) + err = mp.Receive(rcvCtx, rcv) rcvCtxCancel() srs := rcv.SRs for _, sr := range srs { @@ -1428,7 +1428,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) rcv.SRs = nil rcv.DCs = make([]*wire.MsgMixDCNet, 0, len(prs)) rcvCtx, rcvCtxCancel = context.WithDeadline(ctx, d.recvDC) - err = mp.Receive(rcvCtx, len(prs), rcv) + err = mp.Receive(rcvCtx, rcv) rcvCtxCancel() dcs := rcv.DCs for _, dc := range dcs { @@ -1508,7 +1508,7 @@ func (c *Client) run(ctx context.Context, ps *pairedSessions, madePairing *bool) rcv.DCs = nil rcv.CMs = make([]*wire.MsgMixConfirm, 0, len(prs)) rcvCtx, rcvCtxCancel = context.WithDeadline(ctx, d.recvCM) - err = mp.Receive(rcvCtx, len(prs), rcv) + err = mp.Receive(rcvCtx, rcv) rcvCtxCancel() cms := rcv.CMs for _, cm := range cms { @@ -1612,16 +1612,15 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, // We can return a result as soon as we read any valid factored // polynomial message that provides the solutions for this SR // polynomial. - expectedMessages := 1 rcv := &mixpool.Received{ Sid: sesRun.sid, - FPs: make([]*wire.MsgMixFactoredPoly, 0, sesRun.mtot), + FPs: make([]*wire.MsgMixFactoredPoly, 0, 1), } roots := make([]*big.Int, 0, len(a)-1) checkedFPByIdentity := make(map[identity]struct{}) for { rcv.FPs = rcv.FPs[:0] - err := c.mixpool.Receive(ctx, expectedMessages, rcv) + err := c.mixpool.Receive(ctx, rcv) if err != nil { return nil, err } @@ -1663,7 +1662,7 @@ func (c *Client) roots(ctx context.Context, seenSRs []chainhash.Hash, checkedFPByIdentity[fp.Identity] = struct{}{} } - expectedMessages = len(rcv.FPs) + 1 + rcv.FPs = make([]*wire.MsgMixFactoredPoly, 0, len(rcv.FPs)+1) } } diff --git a/mixing/mixpool/mixpool.go b/mixing/mixpool/mixpool.go index 882caec13..78c95d549 100644 --- a/mixing/mixpool/mixpool.go +++ b/mixing/mixpool/mixpool.go @@ -595,8 +595,9 @@ PRLoop: // Received is a parameter for Pool.Receive describing the session and run to // receive messages for, and slices for returning results. A single non-nil -// slice is required and indicates which message slice will be will be -// appended to. Received messages are unsorted. +// slice with capacity of the expected number of messages is required and +// indicates which message slice will be will be appended to. Received +// messages are unsorted. type Received struct { Sid [32]byte KEs []*wire.MsgMixKeyExchange @@ -620,7 +621,7 @@ type Received struct { // r.RSs is nil, Receive immediately returns ErrSecretsRevealed. An // additional call to Receive with a non-nil RSs can be used to receive all of // the secrets after each peer publishes their own revealed secrets. -func (p *Pool) Receive(ctx context.Context, expectedMessages int, r *Received) error { +func (p *Pool) Receive(ctx context.Context, r *Received) error { sid := r.Sid var bc *broadcast @@ -632,30 +633,37 @@ func (p *Pool) Receive(ctx context.Context, expectedMessages int, r *Received) e } bc = &ses.bc - nonNilSlices := 0 - if r.KEs != nil { - nonNilSlices++ + var capSlices, expectedMessages int + if cap(r.KEs) != 0 { + capSlices++ + expectedMessages = cap(r.KEs) } - if r.CTs != nil { - nonNilSlices++ + if cap(r.CTs) != 0 { + capSlices++ + expectedMessages = cap(r.CTs) } - if r.SRs != nil { - nonNilSlices++ + if cap(r.SRs) != 0 { + capSlices++ + expectedMessages = cap(r.SRs) } - if r.DCs != nil { - nonNilSlices++ + if cap(r.DCs) != 0 { + capSlices++ + expectedMessages = cap(r.DCs) } - if r.CMs != nil { - nonNilSlices++ + if cap(r.CMs) != 0 { + capSlices++ + expectedMessages = cap(r.CMs) } - if r.FPs != nil { - nonNilSlices++ + if cap(r.FPs) != 0 { + capSlices++ + expectedMessages = cap(r.FPs) } - if r.RSs != nil { - nonNilSlices++ + if cap(r.RSs) != 0 { + capSlices++ + expectedMessages = cap(r.RSs) } - if nonNilSlices != 1 { - return fmt.Errorf("mixpool: exactly one Received slice must be non-nil") + if capSlices != 1 { + return fmt.Errorf("mixpool: exactly one Received slice must have non-zero capacity") } Loop: