Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Watch the htpasswd file for changes and update the htpasswdMap #1701

Merged
merged 12 commits into from Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion oauthproxy.go
Expand Up @@ -113,7 +113,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
var err error
basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile)
if err != nil {
return nil, fmt.Errorf("could not load htpasswdfile: %v", err)
return nil, fmt.Errorf("could not load htpasswd file: %v", err)
}
}

Expand Down
99 changes: 77 additions & 22 deletions pkg/authentication/basic/htpasswd.go
Expand Up @@ -8,15 +8,18 @@ import (
"fmt"
"io"
"os"
"sync"

"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/watcher"
"golang.org/x/crypto/bcrypt"
)

// htpasswdMap represents the structure of an htpasswd file.
// Passwords must be generated with -B for bcrypt or -s for SHA1.
type htpasswdMap struct {
users map[string]interface{}
rwm sync.RWMutex
}

// bcryptPass is used to identify bcrypt passwords in the
Expand All @@ -30,59 +33,111 @@ type sha1Pass string
// NewHTPasswdValidator constructs an httpasswd based validator from the file
// at the path given.
func NewHTPasswdValidator(path string) (Validator, error) {
h := &htpasswdMap{users: make(map[string]interface{})}

err := h.loadHTPasswdFile(path)
if err != nil {
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
return nil, err
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please wrap errors you are returning with additional context

Suggested change
return nil, err
return nil, fmt.Errorf("could not load htpasswd file: %v", err)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have the context in the oauthproxy.go file -> https://github.com/oauth2-proxy/oauth2-proxy/blob/master/oauthproxy.go#L114-L117.
We just need the value of the error retrieved from loadHTPasswdFile or WatchFileForUpdates.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typically, if you add context at every level, it's easy to see the exact error paths. Else you have to dig into each function call within this function to see which child error came up

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the error message from oauthproxy.go and added the context for load and watch.

}

watcher.WatchFileForUpdates(path, nil, func() {
err := h.loadHTPasswdFile(path)
if err != nil {
logger.Errorf("%v: no changes were made to the current htpasswd map", err)
}
})

return h, nil
}

// loadHTPasswdFile loads htpasswd entries from an io.Reader (an opened file) into a htpasswdMap.
func (h *htpasswdMap) loadHTPasswdFile(filename string) error {
// We allow HTPasswd location via config options
r, err := os.Open(path) // #nosec G304
r, err := os.Open(filename) // #nosec G304
if err != nil {
return nil, fmt.Errorf("could not open htpasswd file: %v", err)
return fmt.Errorf("could not open htpasswd file: %v", err)
}
defer func(c io.Closer) {
cerr := c.Close()
if cerr != nil {
logger.Fatalf("error closing the htpasswd file: %v", cerr)
}
}(r)
return newHtpasswd(r)
}

// newHtpasswd consctructs an htpasswd from an io.Reader (an opened file).
func newHtpasswd(file io.Reader) (*htpasswdMap, error) {
csvReader := csv.NewReader(file)
csvReader := csv.NewReader(r)
csvReader.Comma = ':'
csvReader.Comment = '#'
csvReader.TrimLeadingSpace = true

records, err := csvReader.ReadAll()
if err != nil {
return nil, fmt.Errorf("could not read htpasswd file: %v", err)
logger.Fatalf("could not read htpasswd file: %v", err)
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
}

return createHtpasswdMap(records)
updated, err := createHtpasswdMap(records)
if err != nil {
return fmt.Errorf("htpasswd entries error: %v", err)
}

h.rwm.Lock()
h.users = updated.users
h.rwm.Unlock()

return nil
}

// createHtasswdMap constructs an htpasswdMap from the given records
func createHtpasswdMap(records [][]string) (*htpasswdMap, error) {
h := &htpasswdMap{users: make(map[string]interface{})}
var invalidRecords, invalidEntries []string
for _, record := range records {
user, realPassword := record[0], record[1]
shaPrefix := realPassword[:5]
if shaPrefix == "{SHA}" {
h.users[user] = sha1Pass(realPassword[5:])
continue
// If a record is invalid or malformed don't panic with index out of range,
// return a formatted error.
lr := len(record)
switch {
case lr == 2:
user, realPassword := record[0], record[1]
invalidEntries = passShaOrBcrypt(h, user, realPassword)
case lr == 1, lr > 2:
invalidRecords = append(invalidRecords, record[0])
}
}

bcryptPrefix := realPassword[:4]
if bcryptPrefix == "$2a$" || bcryptPrefix == "$2b$" || bcryptPrefix == "$2x$" || bcryptPrefix == "$2y$" {
h.users[user] = bcryptPass(realPassword)
continue
}
if len(invalidRecords) > 0 {
return h, fmt.Errorf("invalid htpasswd record(s) %+q", invalidRecords)
}

if len(invalidEntries) > 0 {
return h, fmt.Errorf("'%+q' user(s) could not be added: invalid password, must be a SHA or bcrypt entry", invalidEntries)
}

// Password is neither sha1 or bcrypt
// TODO(JoelSpeed): In the next breaking release, make this return an error.
logger.Errorf("Invalid htpasswd entry for %s. Must be a SHA or bcrypt entry.", user)
if len(h.users) == 0 {
logger.Fatal("could not construct htpasswdMap: htpasswd file doesn't contain a single valid user entry")
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
}

return h, nil
}

// passShaOrBcrypt checks if a htpasswd entry is valid and the password is encrypted with SHA or bcrypt.
// Valid user entries are saved in the htpasswdMap, invalid records are reurned.
func passShaOrBcrypt(h *htpasswdMap, user, password string) (invalidEntries []string) {
passLen := len(password)
switch {
case passLen > 6 && password[:5] == "{SHA}":
h.users[user] = sha1Pass(password[5:])
case passLen > 5 &&
(password[:4] == "$2b$" ||
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
password[:4] == "$2y$" ||
password[:4] == "$2x$" ||
password[:4] == "$2a$"):
h.users[user] = bcryptPass(password)
default:
invalidEntries = append(invalidEntries, user)
}

return invalidEntries
}

// Validate checks a users password against the htpasswd entries
func (h *htpasswdMap) Validate(user string, password string) bool {
realPassword, exists := h.users[user]
Expand Down
67 changes: 67 additions & 0 deletions pkg/authentication/basic/htpasswd_test.go
@@ -1,6 +1,8 @@
package basic

import (
"os"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
Expand Down Expand Up @@ -91,6 +93,71 @@ var _ = Describe("HTPasswd Suite", func() {
Expect(validator).To(BeNil())
})
})

Context("htpasswd file is updated", func() {
const filePathPrefix = "htpasswd-file-updated-"
const adminUserHtpasswdEntry = "admin:$2y$05$SXWrNM7ldtbRzBvUC3VXyOvUeiUcP45XPwM93P5eeGOEPIiAZmJjC"
const user1HtpasswdEntry = "user1:$2y$05$/sZYJOk8.3Etg4V6fV7puuXfCJLmV5Q7u3xvKpjBSJUka.t2YtmmG"
var fileNames []string

AfterSuite(func() {
for _, v := range fileNames {
err := os.Remove(v)
Expect(err).ToNot(HaveOccurred())
}

})

htpasswdMap := func(entry, otherEntry string, remove bool) *htpasswdMap {
var validator Validator
var file *os.File
var err error

// Create a temporary file with at least one entry
file, err = os.CreateTemp("", filePathPrefix)
Expect(err).ToNot(HaveOccurred())
_, err = file.WriteString(entry + "\n")
Expect(err).ToNot(HaveOccurred())

validator, err = NewHTPasswdValidator(file.Name())
Expect(err).ToNot(HaveOccurred())

htpasswd, ok := validator.(*htpasswdMap)
Expect(ok).To(BeTrue())

if remove {
// Overwrite the original file with another entry
err = os.WriteFile(file.Name(), []byte(otherEntry+"\n"), 0644)
Expect(err).ToNot(HaveOccurred())
} else {
// Add another entry to the original file in append mode
_, err = file.WriteString(otherEntry + "\n")
Expect(err).ToNot(HaveOccurred())
}

err = file.Close()
Expect(err).ToNot(HaveOccurred())

fileNames = append(fileNames, file.Name())

return htpasswd
}

htpasswdAdd := htpasswdMap(adminUserHtpasswdEntry, user1HtpasswdEntry, false)
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
It("htpasswd entry is added", func() {
Expect(len(htpasswdAdd.users)).To(Equal(2))
Expect(htpasswdAdd.Validate(adminUser, adminPassword)).To(BeTrue())
Expect(htpasswdAdd.Validate(user1, user1Password)).To(BeTrue())
})

htpasswdRemove := htpasswdMap(adminUserHtpasswdEntry, user1HtpasswdEntry, true)
It("htpasswd entry is removed", func() {
Expect(len(htpasswdRemove.users)).To(Equal(1))
Expect(htpasswdRemove.Validate(adminUser, adminPassword)).To(BeFalse())
Expect(htpasswdRemove.Validate(user1, user1Password)).To(BeTrue())
})

})
})
})
})
82 changes: 82 additions & 0 deletions pkg/watcher/watcher.go
@@ -0,0 +1,82 @@
package watcher

import (
"os"
"path/filepath"
"time"

"github.com/fsnotify/fsnotify"

"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
)

// WatchForUpdates performs an action every time a file on disk is updated
func WatchFileForUpdates(filename string, done <-chan bool, action func()) {
filename = filepath.Clean(filename)
watcher, err := fsnotify.NewWatcher()
if err != nil {
logger.Fatalf("failed to create watcher for '%s': %s", filename, err)
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
}

go func() {
defer watcher.Close()

for {
select {
case <-done:
logger.Printf("shutting down watcher for: %s", filename)
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
return
case event, ok := <-watcher.Events:
if !ok { // 'Events' channel is closed
logger.Errorf("error: cannot start the watcher, events channel is closed")
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
return
}
filterEvent(watcher, event, filename, action)
case err = <-watcher.Errors:
logger.Errorf("error watching '%s': %s", filename, err)
}
}
}()
if err = watcher.Add(filename); err != nil {
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
logger.Fatalf("failed to add '%s' to watcher: %v", filename, err)
}
logger.Printf("watching '%s' for updates", filename)
}

// Filter file operations based on the events sent by the watcher.
// Execute the action() function when the following conditions are met:
// - the file is modified or created
// - the real path of the file was changed (Kubernetes ConfigMap/Secret)
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
func filterEvent(watcher *fsnotify.Watcher, event fsnotify.Event, filename string, action func()) {
switch filepath.Clean(event.Name) == filename {
case event.Op&(fsnotify.Remove|fsnotify.Rename) != 0:
logger.Printf("watching interrupted on event: %s", event)
WaitForReplacement(filename, event.Op, watcher)
action()
case event.Op&(fsnotify.Create|fsnotify.Write) != 0:
logger.Printf("reloading after event: %s", event)
action()
default:
logger.Printf("current event: %s", event)
aiciobanu marked this conversation as resolved.
Show resolved Hide resolved
}
}

// WaitForReplacement waits for a file to exist on disk and then starts a watch
// for the file
func WaitForReplacement(filename string, op fsnotify.Op, watcher *fsnotify.Watcher) {
const sleepInterval = 50 * time.Millisecond

// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
if op&fsnotify.Chmod != 0 {
time.Sleep(sleepInterval)
}
for {
if _, err := os.Stat(filename); err == nil {
if err := watcher.Add(filename); err == nil {
logger.Printf("watching resumed for '%s'", filename)
return
}
}
time.Sleep(sleepInterval)
}
}
3 changes: 2 additions & 1 deletion validator.go
Expand Up @@ -9,6 +9,7 @@ import (
"unsafe"

"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/watcher"
)

// UserMap holds information from the authenticated emails file
Expand All @@ -26,7 +27,7 @@ func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap {
atomic.StorePointer(&um.m, unsafe.Pointer(&m)) // #nosec G103
if usersFile != "" {
logger.Printf("using authenticated emails file %s", usersFile)
WatchForUpdates(usersFile, done, func() {
watcher.WatchFileForUpdates(usersFile, done, func() {
um.LoadAuthenticatedEmailsFile()
onUpdate()
})
Expand Down