diff --git a/v8/gssapi/sequence.go b/v8/gssapi/sequence.go new file mode 100644 index 00000000..96cce70e --- /dev/null +++ b/v8/gssapi/sequence.go @@ -0,0 +1,91 @@ +package gssapi + +import ( + "errors" + "math" + "sync" +) + +var ( + errDuplicateToken = errors.New("duplicate per-message token detected") + errOldToken = errors.New("timed-out per-message token detected") + errUnseqToken = errors.New("reordered (early) per-message token detected") + errGapToken = errors.New("skipped predecessor token(s) detected") +) + +// SequenceState tracks previously seen sequence numbers for message replay +// and/or sequence protection +type SequenceState struct { + m sync.Mutex + doReplay bool + doSequence bool + base uint64 + next uint64 + receiveMask uint64 + sequenceMask uint64 +} + +// NewSequenceState returns a new SequenceState seeded with sequenceNumber +// with doReplay and doSequence controlling replay and sequence protection +// respectively and wide controlling whether sequence numbers are expected to +// wrap at a 32- or 64-bit boundary. +func NewSequenceState(sequenceNumber uint64, doReplay, doSequence, wide bool) *SequenceState { + ss := &SequenceState{ + doReplay: doReplay, + doSequence: doSequence, + base: sequenceNumber, + } + if wide { + ss.sequenceMask = math.MaxUint64 + } else { + ss.sequenceMask = math.MaxUint32 + } + return ss +} + +// Check the next sequence number. Sequence protection requires the sequence +// number to increase sequentially with no duplicates or out of order delivery. +// Replay protection relaxes these restrictions to permit limited out of order +// delivery. +func (ss *SequenceState) Check(sequenceNumber uint64) error { + if !ss.doReplay && !ss.doSequence { + return nil + } + + ss.m.Lock() + defer ss.m.Unlock() + + relativeSequenceNumber := (sequenceNumber - ss.base) & ss.sequenceMask + + if relativeSequenceNumber >= ss.next { + offset := relativeSequenceNumber - ss.next + ss.receiveMask = ss.receiveMask<<(offset+1) | 1 + ss.next = (relativeSequenceNumber + 1) & ss.sequenceMask + + if offset > 0 && ss.doSequence { + return errGapToken + } + + return nil + } + + offset := ss.next - relativeSequenceNumber + + if offset > 64 { + if ss.doSequence { + return errUnseqToken + } + return errOldToken + } + + bit := uint64(1) << (offset - 1) + if ss.doReplay && ss.receiveMask&bit != 0 { + return errDuplicateToken + } + ss.receiveMask |= bit + if ss.doSequence { + return errUnseqToken + } + + return nil +} diff --git a/v8/gssapi/sequence_test.go b/v8/gssapi/sequence_test.go new file mode 100644 index 00000000..40f0d2f6 --- /dev/null +++ b/v8/gssapi/sequence_test.go @@ -0,0 +1,115 @@ +package gssapi + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func makeRange(min, max uint64) []uint64 { + a := make([]uint64, max-min+1) + for i := range a { + a[i] = min + uint64(i) + } + return a +} + +func TestSequenceState(t *testing.T) { + tables := map[string]struct { + base uint64 + doReplay bool + doSequence bool + wide bool + sequence []uint64 + err error + }{ + "noop": { + 0, + false, + false, + false, + makeRange(0, 64), + nil, + }, + "ok": { + 0, + true, + true, + true, + makeRange(0, 64), + nil, + }, + "replay skip": { + 0, + true, + false, + true, + append(makeRange(0, 64), 66), + nil, + }, + "sequence skip": { + 0, + false, + true, + true, + append(makeRange(0, 64), 66), + errGapToken, + }, + "replay too old": { + 0, + true, + false, + true, + append(makeRange(0, 64), 0), + errOldToken, + }, + "sequence too old": { + 0, + false, + true, + true, + append(makeRange(0, 64), 0), + errUnseqToken, + }, + "replay duplicate": { + 0, + true, + false, + true, + append(makeRange(0, 64), 64), + errDuplicateToken, + }, + "sequence duplicate": { + 0, + false, + true, + true, + append(makeRange(0, 64), 64), + errUnseqToken, + }, + "replay out of order": { + 0, + true, + false, + true, + append(makeRange(0, 64), 66, 65), + nil, + }, + } + + for name, table := range tables { + t.Run(name, func(t *testing.T) { + ss := NewSequenceState(table.base, table.doReplay, table.doSequence, table.wide) + + var err error + for _, next := range table.sequence { + err = ss.Check(next) + if err != nil { + break + } + } + + assert.Equal(t, table.err, err) + }) + } +}