Skip to content

Commit

Permalink
internal/transport: move the retrier loop to its own type
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisccoulson committed Feb 27, 2024
1 parent 0b1d44f commit 0c33912
Showing 1 changed file with 48 additions and 24 deletions.
72 changes: 48 additions & 24 deletions internal/transportutil/retrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ type RetryParams struct {
}

type retrierTransport struct {
transport tpm2.Transport
params RetryParams

tomb tomb.Tomb
tomb *tomb.Tomb

w io.WriteCloser // write channel

Expand All @@ -49,10 +46,7 @@ type retrierTransport struct {
// NewRetrierTransport returns a new transport that resubmits commands on certain
// errors, which is necessary for transports that don't already do this.
func NewRetrierTransport(transport tpm2.Transport, params RetryParams) tpm2.Transport {
t := &retrierTransport{
transport: transport,
params: params,
}
t := new(retrierTransport)

// Construct the write channel
wr, ww := io.Pipe()
Expand All @@ -70,9 +64,14 @@ func NewRetrierTransport(transport tpm2.Transport, params RetryParams) tpm2.Tran
closeErr := make(chan error)
t.closeErr = closeErr

tmb := new(tomb.Tomb)
t.tomb = tmb

// Run the transport routine
t.tomb.Go(func() error {
err := t.run(wr, rw, rLen, rErr)
tmb.Go(func() error {
loop := newRetrierTransportLoop(&params, transport, tmb, wr, rw, rLen, rErr)
err := loop.run()

// Ensure the calling routine gets unblocked.
wr.Close()
rw.Close()
Expand All @@ -85,25 +84,50 @@ func NewRetrierTransport(transport tpm2.Transport, params RetryParams) tpm2.Tran
return t
}

func (t *retrierTransport) runCommand(commandCode tpm2.CommandCode, data []byte) ([]byte, error) {
retryDelay := t.params.InitialBackoff
type retrierTransportLoop struct {
params RetryParams
transport tpm2.Transport

tomb *tomb.Tomb

r io.Reader

w io.Writer
wLen chan<- int64
wErr chan<- error
}

func newRetrierTransportLoop(params *RetryParams, transport tpm2.Transport, tomb *tomb.Tomb, r io.Reader, w io.Writer, wLen chan<- int64, wErr chan<- error) *retrierTransportLoop {
return &retrierTransportLoop{
params: *params,
transport: transport,
tomb: tomb,
r: r,
w: w,
wLen: wLen,
wErr: wErr,
}
}

func (l *retrierTransportLoop) runCommand(commandCode tpm2.CommandCode, data []byte) ([]byte, error) {
retryDelay := l.params.InitialBackoff

for retries := t.params.MaxRetries; ; retries-- {
if !t.tomb.Alive() {
for retries := l.params.MaxRetries; ; retries-- {
if !l.tomb.Alive() {
return nil, ErrClosed
}

// Send the command.
if _, err := t.transport.Write(data); err != nil {
if _, err := l.transport.Write(data); err != nil {
return nil, fmt.Errorf("cannot send command: %w", err)
}

if !t.tomb.Alive() {
if !l.tomb.Alive() {
return nil, ErrClosed
}

rsp := new(bytes.Buffer)
tr := io.TeeReader(t.transport, rsp)
tr := io.TeeReader(l.transport, rsp)

// Wait for the response header
var hdr tpm2.ResponseHeader
Expand All @@ -121,18 +145,18 @@ func (t *retrierTransport) runCommand(commandCode tpm2.CommandCode, data []byte)
tpm2.IsTPMWarning(err, tpm2.WarningTesting, commandCode) ||
tpm2.IsTPMWarning(err, tpm2.WarningRetry, commandCode)) {
time.Sleep(retryDelay)
retryDelay *= time.Duration(t.params.BackoffRate)
retryDelay *= time.Duration(l.params.BackoffRate)
continue
}

return rsp.Bytes(), nil
}
}

func (t *retrierTransport) run(r io.Reader, w io.Writer, wLen chan<- int64, wErr chan<- error) (err error) {
func (l *retrierTransportLoop) run() (err error) {
for {
cmd := new(bytes.Buffer)
tr := io.TeeReader(r, cmd)
tr := io.TeeReader(l.r, cmd)

// Wait for the next command header
var hdr tpm2.CommandHeader
Expand All @@ -157,15 +181,15 @@ func (t *retrierTransport) run(r io.Reader, w io.Writer, wLen chan<- int64, wErr
return err
}

rsp, err := t.runCommand(hdr.CommandCode, cmd.Bytes())
rsp, err := l.runCommand(hdr.CommandCode, cmd.Bytes())
switch {
case err != nil:
// Command dispatch failed, send an error to the reader
wErr <- err
l.wErr <- err
default:
// Command was executed, send the response to the reader
wLen <- int64(len(rsp))
_, err := io.Copy(w, bytes.NewReader(rsp))
l.wLen <- int64(len(rsp))
_, err := io.Copy(l.w, bytes.NewReader(rsp))
switch {
case errors.Is(err, io.ErrClosedPipe):
return nil
Expand Down

0 comments on commit 0c33912

Please sign in to comment.