diff --git a/ee/tuf/autoupdate.go b/ee/tuf/autoupdate.go index 8f4d6fec4..4b78226c8 100644 --- a/ee/tuf/autoupdate.go +++ b/ee/tuf/autoupdate.go @@ -91,7 +91,7 @@ type TufAutoupdater struct { updateChannel string pinnedVersions map[autoupdatableBinary]string // maps the binaries to their pinned versions pinnedVersionGetters map[autoupdatableBinary]func() string // maps the binaries to the knapsack function to retrieve updated pinned versions - initialDelayLock *sync.Mutex + initialDelayEnd time.Time updateLock *sync.Mutex interrupt chan struct{} interrupted bool @@ -130,7 +130,7 @@ func NewTufAutoupdater(ctx context.Context, k types.Knapsack, metadataHttpClient binaryLauncher: func() string { return k.PinnedLauncherVersion() }, binaryOsqueryd: func() string { return k.PinnedOsquerydVersion() }, }, - initialDelayLock: &sync.Mutex{}, + initialDelayEnd: time.Now().Add(k.AutoupdateInitialDelay()), updateLock: &sync.Mutex{}, osquerier: osquerier, osquerierRetryInterval: 30 * time.Second, @@ -218,18 +218,15 @@ func DefaultLibraryDirectory(rootDirectory string) string { // we store them in. func (ta *TufAutoupdater) Execute() (err error) { // Delay startup, if initial delay is set - ta.initialDelayLock.Lock() // prevent updates during delay select { case <-ta.interrupt: ta.slogger.Log(context.TODO(), slog.LevelDebug, "received external interrupt during initial delay, stopping", ) - ta.initialDelayLock.Unlock() return nil case <-time.After(ta.knapsack.AutoupdateInitialDelay()): break } - ta.initialDelayLock.Unlock() // For now, tidy the library on startup. In the future, we will tidy the library // earlier, after version selection. @@ -281,9 +278,10 @@ func (ta *TufAutoupdater) Interrupt(_ error) { // Do satisfies the actionqueue.actor interface; it allows the control server to send // requests down to autoupdate immediately. func (ta *TufAutoupdater) Do(data io.Reader) error { - if !ta.initialDelayLock.TryLock() { + if time.Now().Before(ta.initialDelayEnd) { ta.slogger.Log(context.TODO(), slog.LevelWarn, "received update request during initial delay, discarding", + "initial_delay_end", ta.initialDelayEnd.UTC().Format(time.RFC3339), ) // We don't return an error because there's no need for the actionqueue to retry this request -- // we're going to perform an autoupdate check as soon as we exit the delay anyway. @@ -377,7 +375,7 @@ func (ta *TufAutoupdater) FlagsChanged(flagKeys ...keys.FlagKey) { } // No updates, or we're in the initial delay - if len(binariesToCheckForUpdate) == 0 || !ta.initialDelayLock.TryLock() { + if len(binariesToCheckForUpdate) == 0 || time.Now().Before(ta.initialDelayEnd) { return } diff --git a/ee/tuf/autoupdate_test.go b/ee/tuf/autoupdate_test.go index b96380d17..c93ca0ba8 100644 --- a/ee/tuf/autoupdate_test.go +++ b/ee/tuf/autoupdate_test.go @@ -43,6 +43,7 @@ func TestNewTufAutoupdater(t *testing.T) { mockKnapsack.On("MirrorServerURL").Return("https://example.com") mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) mockKnapsack.On("UpdateChannel").Return("nightly") + mockKnapsack.On("AutoupdateInitialDelay").Return(0 * time.Second) mockKnapsack.On("PinnedLauncherVersion").Return("") mockKnapsack.On("PinnedOsquerydVersion").Return("") mockKnapsack.On("RegisterChangeObserver", mock.Anything, keys.UpdateChannel, keys.PinnedLauncherVersion, keys.PinnedOsquerydVersion).Return() @@ -587,6 +588,7 @@ func TestDo(t *testing.T) { mockKnapsack.On("UpdateChannel").Return("nightly") mockKnapsack.On("PinnedLauncherVersion").Return("") mockKnapsack.On("PinnedOsquerydVersion").Return("") + mockKnapsack.On("AutoupdateInitialDelay").Return(0 * time.Second) mockKnapsack.On("AutoupdateErrorsStore").Return(s) mockKnapsack.On("TufServerURL").Return(tufServerUrl) mockKnapsack.On("UpdateDirectory").Return("") @@ -809,6 +811,7 @@ func TestFlagsChanged_UpdateChannelChanged(t *testing.T) { mockKnapsack.On("TufServerURL").Return(tufServerUrl) mockKnapsack.On("UpdateDirectory").Return("") mockKnapsack.On("MirrorServerURL").Return("https://example.com") + mockKnapsack.On("AutoupdateInitialDelay").Return(0 * time.Second) mockKnapsack.On("LocalDevelopmentPath").Return("").Maybe() mockQuerier := newMockQuerier(t) mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) @@ -870,6 +873,7 @@ func TestFlagsChanged_PinnedVersionChanged(t *testing.T) { mockKnapsack.On("UpdateDirectory").Return("") mockKnapsack.On("MirrorServerURL").Return("https://example.com") mockKnapsack.On("LocalDevelopmentPath").Return("").Maybe() + mockKnapsack.On("AutoupdateInitialDelay").Return(0 * time.Second) mockQuerier := newMockQuerier(t) mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) mockKnapsack.On("UpdateChannel").Return("nightly")