Skip to content

Commit

Permalink
Include header when signing v2jwt (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiashanel committed Aug 17, 2020
1 parent d7aee04 commit b9df3db
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 76 deletions.
2 changes: 1 addition & 1 deletion v2/account_claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (a *AccountClaims) Encode(pair nkeys.KeyPair) (string, error) {
sort.Sort(a.Exports)
sort.Sort(a.Imports)
a.Type = AccountClaim
return a.ClaimsData.Encode(pair, a)
return a.ClaimsData.encode(pair, a)
}

// DecodeAccountClaims decodes account claims from a JWT string
Expand Down
2 changes: 1 addition & 1 deletion v2/activation_claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (a *ActivationClaims) Encode(pair nkeys.KeyPair) (string, error) {
return "", errors.New("expected subject to be an account")
}
a.Type = ActivationClaim
return a.ClaimsData.Encode(pair, a)
return a.ClaimsData.encode(pair, a)
}

// DecodeActivationClaims tries to create an activation claim from a JWT string
Expand Down
32 changes: 23 additions & 9 deletions v2/claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ type Claims interface {
Payload() interface{}
String() string
Validate(vr *ValidationResults)
Verify(payload string, sig []byte) bool
ClaimType() ClaimType

verify(payload string, sig []byte) bool
updateVersion()
}

Expand Down Expand Up @@ -102,6 +103,10 @@ func (c *ClaimsData) doEncode(header *Header, kp nkeys.KeyPair, claim Claims) (s
return "", errors.New("keypair is required")
}

if c != claim.Claims() {
return "", errors.New("claim and claim data do not match")
}

if c.Subject == "" {
return "", errors.New("subject is not set")
}
Expand Down Expand Up @@ -150,7 +155,7 @@ func (c *ClaimsData) doEncode(header *Header, kp nkeys.KeyPair, claim Claims) (s

c.Issuer = issuerBytes
c.IssuedAt = time.Now().UTC().Unix()

c.ID = "" // to create a repeatable hash
c.ID, err = c.hash()
if err != nil {
return "", err
Expand All @@ -163,12 +168,21 @@ func (c *ClaimsData) doEncode(header *Header, kp nkeys.KeyPair, claim Claims) (s
return "", err
}

sig, err := kp.Sign([]byte(payload))
if err != nil {
return "", err
toSign := fmt.Sprintf("%s.%s", h, payload)
eSig := ""
if header.Algorithm == AlgorithmNkeyOld {
return "", errors.New(AlgorithmNkeyOld + " not supported to write jwtV2")
} else if header.Algorithm == AlgorithmNkey {
sig, err := kp.Sign([]byte(toSign))
if err != nil {
return "", err
}
eSig = encodeToString(sig)
} else {
return "", errors.New(header.Algorithm + " not supported to write jwtV2")
}
eSig := encodeToString(sig)
return fmt.Sprintf("%s.%s.%s", h, payload, eSig), nil
// hash need no padding
return fmt.Sprintf("%s.%s", toSign, eSig), nil
}

func (c *ClaimsData) hash() (string, error) {
Expand All @@ -183,7 +197,7 @@ func (c *ClaimsData) hash() (string, error) {

// Encode encodes a claim into a JWT token. The claim is signed with the
// provided nkey's private key
func (c *ClaimsData) Encode(kp nkeys.KeyPair, payload Claims) (string, error) {
func (c *ClaimsData) encode(kp nkeys.KeyPair, payload Claims) (string, error) {
return c.doEncode(&Header{TokenTypeJwt, AlgorithmNkey}, kp, payload)
}

Expand All @@ -209,7 +223,7 @@ func parseClaims(s string, target Claims) error {
// the claims portion of the token and the public key in the claim.
// Client code need to insure that the public key in the
// claim is trusted.
func (c *ClaimsData) Verify(payload string, sig []byte) bool {
func (c *ClaimsData) verify(payload string, sig []byte) bool {
// decode the public key
kp, err := nkeys.FromPublicKey(c.Issuer)
if err != nil {
Expand Down
29 changes: 18 additions & 11 deletions v2/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func Decode(token string) (Claims, error) {
if err != nil {
return nil, err
}
claim, err := loadClaims(data)
ver, claim, err := loadClaims(data)
if err != nil {
return nil, err
}
Expand All @@ -81,8 +81,15 @@ func Decode(token string) (Claims, error) {
if err != nil {
return nil, err
}
if !claim.Verify(chunks[1], sig) {
return nil, errors.New("claim failed signature verification")

if ver <= 1 {
if !claim.verify(chunks[1], sig) {
return nil, errors.New("claim failed V1 signature verification")
}
} else {
if !claim.verify(token[:len(chunks[0])+len(chunks[1])+1], sig) {
return nil, errors.New("claim failed V2 signature verification")
}
}

prefixes := claim.ExpectedPrefixes()
Expand Down Expand Up @@ -112,14 +119,14 @@ func Decode(token string) (Claims, error) {
return claim, nil
}

func loadClaims(data []byte) (Claims, error) {
func loadClaims(data []byte) (int, Claims, error) {
var id identifier
if err := json.Unmarshal(data, &id); err != nil {
return nil, err
return -1, nil, err
}

if id.Version() > libVersion {
return nil, errors.New("JWT was generated by a newer version ")
return -1, nil, errors.New("JWT was generated by a newer version ")
}

var claim Claims
Expand All @@ -134,16 +141,16 @@ func loadClaims(data []byte) (Claims, error) {
case ActivationClaim:
claim, err = loadActivation(data, id.Version())
case "cluster":
return nil, errors.New("ClusterClaims are not supported")
return -1, nil, errors.New("ClusterClaims are not supported")
case "server":
return nil, errors.New("ServerClaims are not supported")
return -1, nil, errors.New("ServerClaims are not supported")
default:
var gc GenericClaims
if err := json.Unmarshal(data, &gc); err != nil {
return nil, err
return -1, nil, err
}
return &gc, nil
return -1, &gc, nil
}

return claim, err
return id.Version(), claim, err
}
75 changes: 36 additions & 39 deletions v2/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,16 @@ func TestBadAlgo(t *testing.T) {
c := NewGenericClaims(publicKey(createUserNKey(t), t))
c.Data["foo"] = "bar"

token, err := c.doEncode(&h, kp, c)
if err != nil {
t.Fatal(err)
if _, err := c.doEncode(&h, kp, c); err == nil {
t.Fatal("expected an error due to bad algorithm")
}

claim, err := DecodeGeneric(token)
if claim != nil {
t.Fatal("non nil claim on bad token")
}

if err == nil {
t.Fatal("nil error on bad token")
}
h = Header{TokenTypeJwt, AlgorithmNkeyOld}
c = NewGenericClaims(publicKey(createUserNKey(t), t))
c.Data["foo"] = "bar"

if err.Error() != fmt.Sprintf("unexpected %q algorithm", "foobar") {
t.Fatal("expected unexpected algorithm")
if _, err := c.doEncode(&h, kp, c); err == nil {
t.Fatal("expected an error due to bad algorithm")
}
}

Expand Down Expand Up @@ -150,30 +144,33 @@ func TestBadJWT(t *testing.T) {

func TestBadSignature(t *testing.T) {
kp := createAccountNKey(t)

h := Header{TokenTypeJwt, AlgorithmNkey}
c := NewGenericClaims(publicKey(createUserNKey(t), t))
c.Data["foo"] = "bar"

token, err := c.doEncode(&h, kp, c)
if err != nil {
t.Fatal(err)
}

token = token + "A"

claim, err := DecodeGeneric(token)
if claim != nil {
t.Fatal("non nil claim on bad token")
}

if err == nil {
t.Fatal("nil error on bad token")
}

if err.Error() != "claim failed signature verification" {
m := fmt.Sprintf("expected failed signature: %q", err.Error())
t.Fatal(m)
for algo, error := range map[string]string{
AlgorithmNkey: "claim failed V2 signature verification",
} {
h := Header{TokenTypeJwt, algo}
c := NewGenericClaims(publicKey(createUserNKey(t), t))
c.Data["foo"] = "bar"

token, err := c.doEncode(&h, kp, c)
if err != nil {
t.Fatal(err)
}

token = token + "A"

claim, err := DecodeGeneric(token)
if claim != nil {
t.Fatal("non nil claim on bad token")
}

if err == nil {
t.Fatal("nil error on bad token")
}

if err.Error() != error {
m := fmt.Sprintf("expected failed signature: %q", err.Error())
t.Fatal(m)
}
}
}

Expand Down Expand Up @@ -201,7 +198,7 @@ func TestDifferentPayload(t *testing.T) {
t.Fatal("nil error on bad token")
}

if err.Error() != "claim failed signature verification" {
if err.Error() != "claim failed V2 signature verification" {
m := fmt.Sprintf("expected failed signature: %q", err.Error())
t.Fatal(m)
}
Expand Down Expand Up @@ -332,7 +329,7 @@ func TestBadClaimsJSON(t *testing.T) {
func TestBadPublicKeyDecodeGeneric(t *testing.T) {
c := &GenericClaims{}
c.Issuer = "foo"
if ok := c.Verify("foo", []byte("bar")); ok {
if ok := c.verify("foo", []byte("bar")); ok {
t.Fatal("Should have failed to verify")
}
}
Expand Down
26 changes: 15 additions & 11 deletions v2/genericlaims.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func DecodeGeneric(token string) (*GenericClaims, error) {
}

// header
if _, err := parseHeaders(chunks[0]); err != nil {
header, err := parseHeaders(chunks[0])
if err != nil {
return nil, err
}
// claim
Expand All @@ -68,10 +69,16 @@ func DecodeGeneric(token string) (*GenericClaims, error) {
if err != nil {
return nil, err
}
if !gc.Verify(chunks[1], sig) {
return nil, errors.New("claim failed signature verification")
}

if header.Algorithm == AlgorithmNkeyOld {
if !gc.verify(chunks[1], sig) {
return nil, errors.New("claim failed V1 signature verification")
}
} else {
if !gc.verify(token[:len(chunks[0])+len(chunks[1])+1], sig) {
return nil, errors.New("claim failed V2 signature verification")
}
}
return &gc, nil
}

Expand All @@ -87,7 +94,7 @@ func (gc *GenericClaims) Payload() interface{} {

// Encode takes a generic claims and creates a JWT string
func (gc *GenericClaims) Encode(pair nkeys.KeyPair) (string, error) {
return gc.ClaimsData.Encode(pair, gc)
return gc.ClaimsData.encode(pair, gc)
}

// Validate checks the generic part of the claims data
Expand Down Expand Up @@ -123,11 +130,8 @@ func (gc *GenericClaims) ClaimType() ClaimType {
}

func (gc *GenericClaims) updateVersion() {
v, ok := gc.Data["nats"]
if ok {
m, ok := v.(map[string]interface{})
if ok {
m["version"] = libVersion
}
if gc.Data != nil {
// store as float as that is what decoding with json does too
gc.Data["version"] = float64(libVersion)
}
}
6 changes: 4 additions & 2 deletions v2/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ const (

// TokenTypeJwt is the JWT token type supported JWT tokens
// encoded and decoded by this library
TokenTypeJwt = "jwt"
// from RFC7519 5.1 "typ":
// it is RECOMMENDED that "JWT" always be spelled using uppercase characters for compatibility
TokenTypeJwt = "JWT"

// AlgorithmNkey is the algorithm supported by JWT tokens
// encoded and decoded by this library
Expand Down Expand Up @@ -61,7 +63,7 @@ func parseHeaders(s string) (*Header, error) {
// Valid validates the Header. It returns nil if the Header is
// a JWT header, and the algorithm used is the NKEY algorithm.
func (h *Header) Valid() error {
if TokenTypeJwt != strings.ToLower(h.Type) {
if TokenTypeJwt != strings.ToUpper(h.Type) {
return fmt.Errorf("not supported type %q", h.Type)
}

Expand Down
2 changes: 1 addition & 1 deletion v2/operator_claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (oc *OperatorClaims) Encode(pair nkeys.KeyPair) (string, error) {
return "", err
}
oc.Type = OperatorClaim
return oc.ClaimsData.Encode(pair, oc)
return oc.ClaimsData.encode(pair, oc)
}

func (oc *OperatorClaims) ClaimType() ClaimType {
Expand Down
2 changes: 1 addition & 1 deletion v2/user_claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (u *UserClaims) Encode(pair nkeys.KeyPair) (string, error) {
return "", errors.New("expected subject to be user public key")
}
u.Type = UserClaim
return u.ClaimsData.Encode(pair, u)
return u.ClaimsData.encode(pair, u)
}

// DecodeUserClaims tries to parse a user claims from a JWT string
Expand Down

0 comments on commit b9df3db

Please sign in to comment.