Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v9] Fix known_hosts locking by refactoring our locks in utils/fs #16444

Merged
merged 5 commits into from Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 2 additions & 17 deletions lib/client/keystore.go
Expand Up @@ -31,8 +31,6 @@ import (

"golang.org/x/crypto/ssh"

"github.com/gofrs/flock"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/utils/keypaths"
Expand Down Expand Up @@ -555,24 +553,11 @@ func (fs *fsLocalNonSessionKeyStore) kubeCertPath(idx KeyIndex, kubename string)
return keypaths.KubeCertPath(fs.KeyDir, idx.ProxyHost, idx.Username, idx.ClusterName, kubename)
}

// acquireFileLock keeps attempting to lock filePath until the lock is taken or
// the timeout elapses. The file will be created if it doesn't exist.
//
// On success it returns the function that releases the lock; otherwise it
// returns the error reported by the lock attempt.
func acquireFileLock(filePath string, timeout time.Duration) (func() error, error) {
	deadlineCtx, stop := context.WithTimeout(context.Background(), timeout)
	defer stop()

	lock := flock.New(filePath)
	// 10ms is handed to the lock library as the delay between attempts.
	_, err := lock.TryLockContext(deadlineCtx, 10*time.Millisecond)
	if err != nil {
		return nil, err
	}

	return lock.Unlock, nil
}

// AddKnownHostKeys adds a new entry to `known_hosts` file.
func (fs *fsLocalNonSessionKeyStore) AddKnownHostKeys(hostname, proxyHost string, hostKeys []ssh.PublicKey) (retErr error) {
// We're trying to serialize our writes to the 'known_hosts' file to avoid corruption, since there
// are cases when multiple tsh instances will try to write to it.
unlock, err := acquireFileLock(fs.knownHostsPath(), 5*time.Second)
unlock, err := utils.FSTryWriteLockTimeout(context.Background(), fs.knownHostsPath(), 5*time.Second)
if err != nil {
return trace.WrapWithMessage(err, "could not acquire lock for the `known_hosts` file")
}
Expand Down Expand Up @@ -668,7 +653,7 @@ func matchesWildcard(hostname, pattern string) bool {

// GetKnownHostKeys returns all known public keys from `known_hosts`.
func (fs *fsLocalNonSessionKeyStore) GetKnownHostKeys(hostname string) (keys []ssh.PublicKey, retErr error) {
unlock, err := acquireFileLock(fs.knownHostsPath(), 5*time.Second)
unlock, err := utils.FSTryReadLockTimeout(context.Background(), fs.knownHostsPath(), 5*time.Second)
if err != nil {
return nil, trace.WrapWithMessage(err, "could not acquire lock for the `known_hosts` file")
}
Expand Down
18 changes: 10 additions & 8 deletions lib/events/filesessions/fileasync.go
Expand Up @@ -115,7 +115,6 @@ func NewUploader(cfg UploaderConfig) (*Uploader, error) {
// the upload that have been aborted.
//
// It marks corrupted session files to skip their processing.
//
type Uploader struct {
semaphore chan struct{}

Expand Down Expand Up @@ -242,7 +241,7 @@ func (u *Uploader) Scan(ctx context.Context) (*ScanStats, error) {
}
stats.Scanned++
if err := u.startUpload(ctx, fi.Name()); err != nil {
if trace.IsCompareFailed(err) {
if errors.Is(err, utils.ErrUnsuccessfulLockTry) {
u.log.Debugf("Scan is skipping recording %v that is locked by another process.", fi.Name())
continue
}
Expand Down Expand Up @@ -278,6 +277,7 @@ type upload struct {
sessionID session.ID
reader *events.ProtoReader
file *os.File
fileUnlockFn func() error
checkpointFile *os.File
}

Expand Down Expand Up @@ -323,7 +323,7 @@ func (u *upload) writeStatus(status apievents.StreamStatus) error {
func (u *upload) Close() error {
return trace.NewAggregate(
u.reader.Close(),
utils.FSUnlock(u.file),
u.fileUnlockFn(),
u.file.Close(),
utils.NilCloser(u.checkpointFile).Close(),
)
Expand Down Expand Up @@ -367,17 +367,19 @@ func (u *Uploader) startUpload(ctx context.Context, fileName string) error {
if err != nil {
return trace.ConvertSystemError(err)
}
if err := utils.FSTryWriteLock(sessionFile); err != nil {
unlock, err := utils.FSTryWriteLock(sessionFilePath)
if err != nil {
if e := sessionFile.Close(); e != nil {
u.log.WithError(e).Warningf("Failed to close %v.", fileName)
}
return trace.Wrap(err)
return trace.WrapWithMessage(err, "could not acquire file lock for %q", sessionFilePath)
}

upload := &upload{
sessionID: sessionID,
reader: events.NewProtoReader(sessionFile),
file: sessionFile,
sessionID: sessionID,
reader: events.NewProtoReader(sessionFile),
file: sessionFile,
fileUnlockFn: unlock,
}
upload.checkpointFile, err = os.OpenFile(u.checkpointFilePath(sessionID), os.O_RDWR|os.O_CREATE, 0600)
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions lib/events/filesessions/filestream.go
Expand Up @@ -112,11 +112,12 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload
if err != nil {
return trace.ConvertSystemError(err)
}
if err := utils.FSTryWriteLock(f); err != nil {
return trace.Wrap(err)
unlock, err := utils.FSTryWriteLock(uploadPath)
if err != nil {
return trace.WrapWithMessage(err, "could not acquire file lock for %q", uploadPath)
}
defer func() {
if err := utils.FSUnlock(f); err != nil {
if err := unlock(); err != nil {
h.WithError(err).Errorf("Failed to unlock filesystem lock.")
}
if err := f.Close(); err != nil {
Expand Down
5 changes: 3 additions & 2 deletions lib/events/uploader.go
Expand Up @@ -233,7 +233,8 @@ func (u *Uploader) uploadFile(lockFilePath string, sessionID session.ID) error {
if err != nil {
return trace.ConvertSystemError(err)
}
if err := utils.FSTryWriteLock(lockFile); err != nil {
unlock, err := utils.FSTryWriteLock(lockFilePath)
if err != nil {
return trace.Wrap(err)
}
reader, err := NewSessionArchive(u.DataDir, u.ServerID, u.Namespace, sessionID)
Expand All @@ -247,7 +248,7 @@ func (u *Uploader) uploadFile(lockFilePath string, sessionID session.ID) error {
defer u.releaseSemaphore()
defer reader.Close()
defer lockFile.Close()
defer utils.FSUnlock(lockFile)
defer unlock()

start := time.Now()
err := u.AuditLog.UploadSessionRecording(SessionRecording{
Expand Down
65 changes: 65 additions & 0 deletions lib/utils/fs.go
Expand Up @@ -17,15 +17,24 @@ limitations under the License.
package utils

import (
"context"
"errors"
"os"
"path/filepath"
"runtime"
"time"

"github.com/gofrs/flock"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/trace"
)

// ErrUnsuccessfulLockTry indicates that the lock could not be acquired right now
// (most likely because someone else already holds it); a later attempt might succeed.
var ErrUnsuccessfulLockTry = errors.New("could not acquire lock on the file at this time")

// EnsureLocalPath makes sure the path exists, or, if omitted results in the subpath in
// default gravity config directory, e.g.
//
Expand Down Expand Up @@ -156,3 +165,59 @@ func getHomeDir() string {
}
return ""
}

// FSTryWriteLock makes a single attempt to take the write lock on the lock
// file that getPlatformLockFilePath selects for filePath.
//
// If the attempt completes without an error but the lock was not obtained, the
// returned error wraps ErrUnsuccessfulLockTry (via trace.Retry). Errors from
// the lock attempt itself are passed through trace.ConvertSystemError. On
// success the returned unlock function releases the lock.
func FSTryWriteLock(filePath string) (unlock func() error, err error) {
	lock := flock.New(getPlatformLockFilePath(filePath))

	acquired, lockErr := lock.TryLock()
	switch {
	case lockErr != nil:
		return nil, trace.ConvertSystemError(lockErr)
	case !acquired:
		return nil, trace.Retry(ErrUnsuccessfulLockTry, "")
	}

	return unlockWrapper(lock.Unlock, lock.Path()), nil
}

// FSTryWriteLockTimeout keeps trying to take the write lock on the lock file
// that getPlatformLockFilePath selects for filePath until the lock is
// acquired, the timeout expires, or ctx is done, whichever happens first.
//
// On failure the error from the lock attempt is passed through
// trace.ConvertSystemError. On success the returned unlock function releases
// the lock.
func FSTryWriteLockTimeout(ctx context.Context, filePath string, timeout time.Duration) (unlock func() error, err error) {
	lock := flock.New(getPlatformLockFilePath(filePath))

	// Derive a context bounded by both the caller's ctx and the timeout.
	deadlineCtx, stop := context.WithTimeout(ctx, timeout)
	defer stop()

	// 10ms is handed to the lock library as the delay between attempts.
	_, lockErr := lock.TryLockContext(deadlineCtx, 10*time.Millisecond)
	if lockErr != nil {
		return nil, trace.ConvertSystemError(lockErr)
	}

	return unlockWrapper(lock.Unlock, lock.Path()), nil
}

// FSTryReadLock makes a single attempt to take the read lock on the lock file
// that getPlatformLockFilePath selects for filePath.
//
// If the attempt completes without an error but the lock was not obtained, the
// returned error wraps ErrUnsuccessfulLockTry (via trace.Retry). Errors from
// the lock attempt itself are passed through trace.ConvertSystemError. On
// success the returned unlock function releases the lock.
func FSTryReadLock(filePath string) (unlock func() error, err error) {
	lock := flock.New(getPlatformLockFilePath(filePath))

	acquired, lockErr := lock.TryRLock()
	switch {
	case lockErr != nil:
		return nil, trace.ConvertSystemError(lockErr)
	case !acquired:
		return nil, trace.Retry(ErrUnsuccessfulLockTry, "")
	}

	return unlockWrapper(lock.Unlock, lock.Path()), nil
}

// FSTryReadLockTimeout keeps trying to take the read lock on the lock file
// that getPlatformLockFilePath selects for filePath until the lock is
// acquired, the timeout expires, or ctx is done, whichever happens first.
//
// On failure the error from the lock attempt is passed through
// trace.ConvertSystemError. On success the returned unlock function releases
// the lock.
func FSTryReadLockTimeout(ctx context.Context, filePath string, timeout time.Duration) (unlock func() error, err error) {
	lock := flock.New(getPlatformLockFilePath(filePath))

	// Derive a context bounded by both the caller's ctx and the timeout.
	deadlineCtx, stop := context.WithTimeout(ctx, timeout)
	defer stop()

	// 10ms is handed to the lock library as the delay between attempts.
	_, lockErr := lock.TryRLockContext(deadlineCtx, 10*time.Millisecond)
	if lockErr != nil {
		return nil, trace.ConvertSystemError(lockErr)
	}

	return unlockWrapper(lock.Unlock, lock.Path()), nil
}
92 changes: 92 additions & 0 deletions lib/utils/fs_test.go
@@ -0,0 +1,92 @@
/*
Copyright 2022 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package utils

import (
"context"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"
)

// TestLocks exercises the path-based lock helpers against a single temporary
// file: single-shot read/write locks, the conflicts between them, and the
// timeout/context-bounded variants.
func TestLocks(t *testing.T) {
	t.Parallel()

	tmpFile, err := os.CreateTemp("", "teleport-lock-test")
	// Check the error before touching tmpFile: it is nil when creation fails,
	// and calling Name() on it would panic instead of failing the test.
	require.NoError(t, err)
	fp := tmpFile.Name()
	t.Cleanup(func() {
		_ = os.Remove(fp)
	})
	// Only the path is used below (the lock helpers take a path), so release
	// the handle right away instead of leaking it for the rest of the test.
	require.NoError(t, tmpFile.Close())

	// Can take read lock
	unlock, err := FSTryReadLock(fp)
	require.NoError(t, err)

	require.NoError(t, unlock())

	// Can take write lock
	unlock, err = FSTryWriteLock(fp)
	require.NoError(t, err)

	// Can't take read lock while write lock is held.
	unlock2, err := FSTryReadLock(fp)
	require.ErrorIs(t, err, ErrUnsuccessfulLockTry)
	require.Nil(t, unlock2)

	// Can't take write lock while another write lock is held.
	unlock2, err = FSTryWriteLock(fp)
	require.ErrorIs(t, err, ErrUnsuccessfulLockTry)
	require.Nil(t, unlock2)

	require.NoError(t, unlock())

	unlock, err = FSTryReadLock(fp)
	require.NoError(t, err)

	// Can take second read lock on the same file.
	unlock2, err = FSTryReadLock(fp)
	require.NoError(t, err)

	require.NoError(t, unlock())
	require.NoError(t, unlock2())

	// Can take read lock with timeout
	unlock, err = FSTryReadLockTimeout(context.Background(), fp, time.Second)
	require.NoError(t, err)
	require.NoError(t, unlock())

	// Can take write lock with timeout
	unlock, err = FSTryWriteLockTimeout(context.Background(), fp, time.Second)
	require.NoError(t, err)

	// Fails because timeout is exceeded, since file is already locked.
	unlock2, err = FSTryWriteLockTimeout(context.Background(), fp, time.Millisecond)
	require.ErrorIs(t, err, context.DeadlineExceeded)
	require.Nil(t, unlock2)

	// Fails because context is expired while waiting for timeout.
	ctx, cancel := context.WithDeadline(context.Background(), time.Now())
	defer cancel()
	unlock2, err = FSTryWriteLockTimeout(ctx, fp, time.Hour*1000)
	require.ErrorIs(t, err, context.DeadlineExceeded)
	require.Nil(t, unlock2)

	require.NoError(t, unlock())
}
48 changes: 8 additions & 40 deletions lib/utils/fs_unix.go
Expand Up @@ -19,48 +19,16 @@ limitations under the License.

package utils

import (
"os"
"syscall"

"github.com/gravitational/trace"
)

// FSWriteLock grabs Flock-style filesystem lock on an open file
// in exclusive mode.
func FSWriteLock(f *os.File) error {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
return trace.ConvertSystemError(err)
}
return nil
// On non-windows we just lock the target file itself.
func getPlatformLockFilePath(path string) string {
return path
}

// FSTryWriteLock tries to grab write lock, returns CompareFailed
// if lock is already acquired
func FSTryWriteLock(f *os.File) error {
err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
if err != nil {
if err == syscall.EWOULDBLOCK {
return trace.CompareFailed("lock %v is acquired by another process", f.Name())
func unlockWrapper(unlockFn func() error, path string) func() error {
return func() error {
if unlockFn == nil {
return nil
}
return trace.ConvertSystemError(err)
}
return nil
}

// FSReadLock grabs Flock-style filesystem lock on an open file
// in read (shared) mode
func FSReadLock(f *os.File) error {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_SH); err != nil {
return trace.ConvertSystemError(err)
}
return nil
}

// FSUnlock unlocks Flock-style filesystem lock
func FSUnlock(f *os.File) error {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil {
return trace.ConvertSystemError(err)
return unlockFn()
}
return nil
}