diff --git a/config.go b/config.go index 1fc1974..a4b36b0 100644 --- a/config.go +++ b/config.go @@ -3,7 +3,9 @@ package main import ( "errors" "fmt" + "net/url" "os" + "regexp" "strconv" "strings" @@ -17,7 +19,7 @@ var ( cfg = &jwt.Config{} listenAddr = flag.String("listen-addr", "0.0.0.0", "Listen address") listenPort = flag.String("listen-port", "", "Listen port (default: 80 for HTTP or 443 for HTTPS)") - audiences = flag.String("audiences", "", "Comma separated list of JWT Audiences (format: https://yourdomain or https://yourdomain:port)") + audiences = flag.String("audiences", "", "Comma-separated list of JWT Audiences (elements can be URLs like \"https://exammple.com:port\" or regular expressions like \"/^https://example\\.com:port$/\" if you enclose them in slashes)") publicKeysPath = flag.String("public-keys", "", "Path to public keys file (optional)") tlsCertPath = flag.String("tls-cert", "", "Path to TLS server's, intermediate's and CA's PEM certificate (optional)") tlsKeyPath = flag.String("tls-key", "", "Path to TLS server's PEM key file (optional)") @@ -35,7 +37,7 @@ func initConfig() error { if len(*audiences) == 0 { return errors.New("You must specify --audiences") } - if err := initAudiences(strings.Split(*audiences, ",")); err != nil { + if err := initAudiences(*audiences); err != nil { return err } if err := initPublicKeys(*publicKeysPath); err != nil { @@ -58,15 +60,47 @@ func initServerPort() error { return nil } -func initAudiences(rawURLs []string) error { - for _, rawURL := range rawURLs { - aud, err := jwt.ParseAudience(rawURL) +func initAudiences(audiences string) error { + str, err := extractAudiencesRegexp(audiences) + if err != nil { + return err + } + re, err := regexp.Compile(str) + if err != nil { + return fmt.Errorf("Invalid audiences regular expression %q (%v)", str, err) + } + cfg.MatchAudiences = re + return nil +} + +func extractAudiencesRegexp(audiences string) (string, error) { + var strs []string + for _, audience := range strings.Split(audiences, ",") { + str, err := extractAudienceRegexp(audience) if err != nil { - return fmt.Errorf("Invalid audience %q (%v)", rawURL, err) + return "", err } - cfg.Audiences = append(cfg.Audiences, aud) + strs = append(strs, str) } - return nil + return strings.Join(strs, "|"), nil +} + +func extractAudienceRegexp(audience string) (string, error) { + if strings.HasPrefix(audience, "/") && strings.HasSuffix(audience, "/") { + if len(audience) < 3 { + return "", fmt.Errorf("Invalid audiences regular expression %q", audience) + } + return audience[1 : len(audience)-1], nil + } + return parseRawAudience(audience) +} + +func parseRawAudience(audience string) (string, error) { + aud, err := jwt.ParseAudience(audience) + if err != nil { + return "", fmt.Errorf("Invalid audience %q (%v)", audience, err) + } + return fmt.Sprintf("^%s$", regexp.QuoteMeta((*url.URL)(aud).String())), nil } func initPublicKeys(filePath string) error { diff --git a/jwt/claims.go b/jwt/claims.go index 4992a44..922417b 100644 --- a/jwt/claims.go +++ b/jwt/claims.go @@ -26,7 +26,7 @@ func (c Claims) Valid() error { if err != nil { return fmt.Errorf("Invalid audience %q: %v", c.Audience, err) } - if !c.cfg.containsAudience(aud) { + if !c.cfg.matchesAudience(aud) { return fmt.Errorf("Unexpected audience: %q", c.Audience) } return nil diff --git a/jwt/config.go b/jwt/config.go index f6b3add..598c806 100644 --- a/jwt/config.go +++ b/jwt/config.go @@ -1,23 +1,22 @@ package jwt -import "errors" +import ( + "errors" + "net/url" + "regexp" +) // Config specifies the parameters for which to perform validation of JWT // tokens in requests against. type Config struct { - PublicKeys map[string]PublicKey - Audiences []*Audience + PublicKeys map[string]PublicKey + MatchAudiences *regexp.Regexp } // Validate validates the Configuration. func (cfg *Config) Validate() error { - if len(cfg.Audiences) == 0 { - return errors.New("No audiences defined") - } - for _, aud := range cfg.Audiences { - if err := aud.Validate(); err != nil { - return err - } + if cfg.MatchAudiences == nil { + return errors.New("No audiences to match defined") } if len(cfg.PublicKeys) == 0 { return errors.New("No public keys defined") @@ -25,11 +24,6 @@ func (cfg *Config) Validate() error { return nil } -func (cfg *Config) containsAudience(aud *Audience) bool { - for _, aud2 := range cfg.Audiences { - if *aud == *aud2 { - return true - } - } - return false +func (cfg *Config) matchesAudience(aud *Audience) bool { + return cfg.MatchAudiences.MatchString((*url.URL)(aud).String()) } diff --git a/main.go b/main.go index 1a9b500..91c5255 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,8 @@ func main() { log.Fatal(err) } + log.Printf("Matching audiences: %s\n", cfg.MatchAudiences) + http.HandleFunc("/auth", authHandler) http.HandleFunc("/healthz", healthzHandler)