diff --git a/claim_strings.go b/claim_strings.go new file mode 100644 index 00000000..24c801c0 --- /dev/null +++ b/claim_strings.go @@ -0,0 +1,37 @@ +package jwt + +import ( + "encoding/json" + "reflect" +) + +// ClaimStrings is used for parsing claim properties that +// can be either a string or array of strings +type ClaimStrings []string + +// UnmarshalJSON implements the json package's Unmarshaler interface +func (c *ClaimStrings) UnmarshalJSON(data []byte) error { + var value interface{} + err := json.Unmarshal(data, &value) + if err != nil { + return err + } + switch v := value.(type) { + case string: + *c = ClaimStrings{v} + case []interface{}: + result := make(ClaimStrings, 0, len(v)) + for i, vv := range v { + if x, ok := vv.(string); ok { + result = append(result, x) + } else { + return &json.UnsupportedTypeError{Type: reflect.TypeOf(v[i])} + } + } + *c = result + case nil: + default: + return &json.UnsupportedTypeError{Type: reflect.TypeOf(v)} + } + return nil +} diff --git a/claim_strings_test.go b/claim_strings_test.go new file mode 100644 index 00000000..2673bd39 --- /dev/null +++ b/claim_strings_test.go @@ -0,0 +1,60 @@ +package jwt_test + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/dgrijalva/jwt-go/v4" +) + +var claimStringsTestData = []struct { + name string + input interface{} + output jwt.ClaimStrings + err error +}{ + { + name: "null", + input: nil, + output: nil, + }, + { + name: "single", + input: "foo", + output: jwt.ClaimStrings{"foo"}, + }, + { + name: "multi", + input: []string{"foo", "bar"}, + output: jwt.ClaimStrings{"foo", "bar"}, + }, + { + name: "invalid", + input: float64(42), + output: nil, + err: &json.UnsupportedTypeError{Type: reflect.TypeOf(float64(42))}, + }, + { + name: "invalid multi", + input: []interface{}{"foo", float64(42)}, + output: nil, + err: &json.UnsupportedTypeError{Type: reflect.TypeOf(float64(42))}, + }, +} + +func TestClaimStrings(t *testing.T) { + for _, test := range claimStringsTestData { + var r *struct { + Value jwt.ClaimStrings `json:"value"` + } + data, _ := json.Marshal(map[string]interface{}{"value": test.input}) + err := json.Unmarshal(data, &r) + if !reflect.DeepEqual(err, test.err) { + t.Errorf("[%v] Error didn't match expectation: %v != %v", test.name, test.err, err) + } + if !reflect.DeepEqual(test.output, r.Value) { + t.Errorf("[%v] Unmarshaled value didn't match expectation: %v != %v", test.name, test.output, r.Value) + } + } +} diff --git a/claims.go b/claims.go index 45918d8e..e5316b97 100644 --- a/claims.go +++ b/claims.go @@ -19,13 +19,13 @@ type Claims interface { // https://tools.ietf.org/html/rfc7519#section-4.1 // See examples for how to use this with your own claim types type StandardClaims struct { - Audience string `json:"aud,omitempty"` - ExpiresAt int64 `json:"exp,omitempty"` - ID string `json:"jti,omitempty"` - IssuedAt int64 `json:"iat,omitempty"` - Issuer string `json:"iss,omitempty"` - NotBefore int64 `json:"nbf,omitempty"` - Subject string `json:"sub,omitempty"` + Audience ClaimStrings `json:"aud,omitempty"` + ExpiresAt int64 `json:"exp,omitempty"` + ID string `json:"jti,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + Issuer string `json:"iss,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + Subject string `json:"sub,omitempty"` } // Valid implements Valid from Claims @@ -94,12 +94,14 @@ func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool { // ----- helpers -func verifyAud(aud string, cmp string, required bool) bool { - if aud == "" { +func verifyAud(aud ClaimStrings, cmp string, required bool) bool { + if len(aud) == 0 { return !required } - if subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 { - return true + for _, audStr := range aud { + if subtle.ConstantTimeCompare([]byte(audStr), []byte(cmp)) != 0 { + return true + } } return false } diff --git a/map_claims.go b/map_claims.go index 7187a1f6..d0fdf7aa 100644 --- a/map_claims.go +++ b/map_claims.go @@ -13,8 +13,19 @@ type MapClaims map[string]interface{} // Compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyAudience(cmp string, req bool) bool { - aud, _ := m["aud"].(string) - return verifyAud(aud, cmp, req) + aud, ok := m["aud"] + if !ok { + return !req + } + + switch v := aud.(type) { + case string: + return verifyAud(ClaimStrings{v}, cmp, req) + case []string: + return verifyAud(ClaimStrings(v), cmp, req) + default: + return false + } } // Compares the exp claim against cmp.