From 832df1cdbada8226832e1314fb56e42344bec01a Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Mon, 15 May 2023 14:38:26 -0700 Subject: [PATCH] Protect against out of bounds access on usage updates. Signed-off-by: Derek Collison --- server/jetstream.go | 88 ++++++++++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 33 deletions(-) diff --git a/server/jetstream.go b/server/jetstream.go index 1f663708a5..50cc8a7ddc 100644 --- a/server/jetstream.go +++ b/server/jetstream.go @@ -1714,14 +1714,13 @@ func (a *Account) JetStreamEnabled() bool { } func (jsa *jsAccount) remoteUpdateUsage(sub *subscription, c *client, _ *Account, subject, _ string, msg []byte) { - const usageSize = 32 - // jsa.js.srv is immutable and guaranteed to no be nil, so no lock needed. s := jsa.js.srv jsa.usageMu.Lock() - if len(msg) < usageSize { - jsa.usageMu.Unlock() + defer jsa.usageMu.Unlock() + + if len(msg) < minUsageUpdateLen { s.Warnf("Ignoring remote usage update with size too short") return } @@ -1730,7 +1729,6 @@ func (jsa *jsAccount) remoteUpdateUsage(sub *subscription, c *client, _ *Account rnode = subject[li+1:] } if rnode == _EMPTY_ { - jsa.usageMu.Unlock() s.Warnf("Received remote usage update with no remote node") return } @@ -1765,21 +1763,31 @@ func (jsa *jsAccount) remoteUpdateUsage(sub *subscription, c *client, _ *Account apiTotal, apiErrors := le.Uint64(msg[16:]), le.Uint64(msg[24:]) memUsed, storeUsed := int64(le.Uint64(msg[0:])), int64(le.Uint64(msg[8:])) - // we later extended the data structure to support multiple tiers - excessRecordCnt := uint32(0) - tierName := _EMPTY_ - if len(msg) >= 44 { - excessRecordCnt = le.Uint32(msg[32:]) - length := le.Uint64(msg[36:]) - tierName = string(msg[44 : 44+length]) - msg = msg[44+length:] + // We later extended the data structure to support multiple tiers + var excessRecordCnt uint32 + var tierName string + + if len(msg) >= usageMultiTiersLen { + excessRecordCnt = le.Uint32(msg[minUsageUpdateLen:]) + length := le.Uint64(msg[minUsageUpdateLen+4:]) + // Need to protect past this point in case this is wrong. + if uint64(len(msg)) < usageMultiTiersLen+length { + s.Warnf("Received corrupt remote usage update") + return + } + tierName = string(msg[usageMultiTiersLen : usageMultiTiersLen+length]) + msg = msg[usageMultiTiersLen+length:] } updateTotal(tierName, memUsed, storeUsed) - for ; excessRecordCnt > 0 && len(msg) >= 24; excessRecordCnt-- { + for ; excessRecordCnt > 0 && len(msg) >= usageRecordLen; excessRecordCnt-- { memUsed, storeUsed := int64(le.Uint64(msg[0:])), int64(le.Uint64(msg[8:])) length := le.Uint64(msg[16:]) - tierName = string(msg[24 : 24+length]) - msg = msg[24+length:] + if uint64(len(msg)) < usageRecordLen+length { + s.Warnf("Received corrupt remote usage update on excess record") + return + } + tierName = string(msg[usageRecordLen : usageRecordLen+length]) + msg = msg[usageRecordLen+length:] updateTotal(tierName, memUsed, storeUsed) } jsa.apiTotal -= rUsage.api @@ -1788,7 +1796,6 @@ func (jsa *jsAccount) remoteUpdateUsage(sub *subscription, c *client, _ *Account rUsage.err = apiErrors jsa.apiTotal += apiTotal jsa.apiErrors += apiErrors - jsa.usageMu.Unlock() } // When we detect a skew of some sort this will verify the usage reporting is correct. @@ -1906,12 +1913,22 @@ func (jsa *jsAccount) sendClusterUsageUpdateTimer() { } } +// For usage fields. +const ( + minUsageUpdateLen = 32 + stackUsageUpdate = 72 + usageRecordLen = 24 + usageMultiTiersLen = 44 + apiStatsAndNumTiers = 20 + minUsageUpdateWindow = 250 * time.Millisecond +) + // Send updates to our account usage for this server. // jsa.usageMu lock should be held. func (jsa *jsAccount) sendClusterUsageUpdate() { // These values are absolute so we can limit send rates. now := time.Now() - if now.Sub(jsa.lupdate) < 250*time.Millisecond { + if now.Sub(jsa.lupdate) < minUsageUpdateWindow { return } jsa.lupdate = now @@ -1921,32 +1938,37 @@ func (jsa *jsAccount) sendClusterUsageUpdate() { return } // every base record contains mem/store/len(tier) as well as the tier name - l := 24 * lenUsage + l := usageRecordLen * lenUsage for tier := range jsa.usage { l += len(tier) } - if lenUsage > 0 { - // first record contains api/usage errors as well as count for extra base records - l += 20 + // first record contains api/usage errors as well as count for extra base records + l += apiStatsAndNumTiers + + var raw [stackUsageUpdate]byte + var b []byte + if l > stackUsageUpdate { + b = make([]byte, l) + } else { + b = raw[:l] } - var le = binary.LittleEndian - b := make([]byte, l) - i := 0 + var i int + var le = binary.LittleEndian for tier, usage := range jsa.usage { le.PutUint64(b[i+0:], uint64(usage.local.mem)) le.PutUint64(b[i+8:], uint64(usage.local.store)) if i == 0 { - le.PutUint64(b[i+16:], jsa.usageApi) - le.PutUint64(b[i+24:], jsa.usageErr) - le.PutUint32(b[i+32:], uint32(len(jsa.usage)-1)) - le.PutUint64(b[i+36:], uint64(len(tier))) - copy(b[i+44:], tier) - i += 44 + len(tier) + le.PutUint64(b[16:], jsa.usageApi) + le.PutUint64(b[24:], jsa.usageErr) + le.PutUint32(b[32:], uint32(len(jsa.usage)-1)) + le.PutUint64(b[36:], uint64(len(tier))) + copy(b[usageMultiTiersLen:], tier) + i = usageMultiTiersLen + len(tier) } else { le.PutUint64(b[i+16:], uint64(len(tier))) - copy(b[i+24:], tier) - i += 24 + len(tier) + copy(b[i+usageRecordLen:], tier) + i += usageRecordLen + len(tier) } } jsa.sendq.push(newPubMsg(nil, jsa.updatesPub, _EMPTY_, nil, nil, b, noCompression, false, false))