diff --git a/algorithms.go b/algorithms.go index ec2c434..cce6615 100644 --- a/algorithms.go +++ b/algorithms.go @@ -72,11 +72,10 @@ func (v nonevalidator) sign(jwt *jwt) error { return nil } -func (jwt *jwt) rawEncode() { +func (jwt *jwt) rawEncode() (header, payload []byte) { headerBuf := bytes.NewBuffer(nil) payloadBuf := bytes.NewBuffer(nil) - // TODO: Determine if errors here are possible/relevant json.NewEncoder(headerBuf).Encode(jwt.Header) json.NewEncoder(payloadBuf).Encode(jwt.Payload) @@ -86,12 +85,14 @@ func (jwt *jwt) rawEncode() { json.Compact(compactHeaderBuf, headerBuf.Bytes()) json.Compact(compactPayloadBuf, payloadBuf.Bytes()) - jwt.headerRaw = make([]byte, base64.URLEncoding.EncodedLen(len(compactHeaderBuf.Bytes()))) - jwt.payloadRaw = make([]byte, base64.URLEncoding.EncodedLen(len(compactPayloadBuf.Bytes()))) + header = make([]byte, base64.URLEncoding.EncodedLen(len(compactHeaderBuf.Bytes()))) + payload = make([]byte, base64.URLEncoding.EncodedLen(len(compactPayloadBuf.Bytes()))) - base64.URLEncoding.Encode(jwt.headerRaw, compactHeaderBuf.Bytes()) - base64.URLEncoding.Encode(jwt.payloadRaw, compactPayloadBuf.Bytes()) + base64.URLEncoding.Encode(header, compactHeaderBuf.Bytes()) + base64.URLEncoding.Encode(payload, compactPayloadBuf.Bytes()) - jwt.headerRaw = []byte(strings.Trim(string(jwt.headerRaw), "=")) - jwt.payloadRaw = []byte(strings.Trim(string(jwt.payloadRaw), "=")) + header = []byte(strings.Trim(string(header), "=")) + payload = []byte(strings.Trim(string(payload), "=")) + + return header, payload } diff --git a/es.go b/es.go index 932e5a6..71637dc 100644 --- a/es.go +++ b/es.go @@ -60,11 +60,11 @@ func (v ESValidator) sign(jwt *jwt) (err error) { } jwt.Header.Algorithm = v.algorithm - jwt.rawEncode() + header, payload := jwt.rawEncode() // TODO: This block is general. Refactor it out of RS and ES validators hsh := v.hashType.New() - hsh.Write([]byte(string(jwt.headerRaw) + "." + string(jwt.payloadRaw))) + hsh.Write([]byte(string(header) + "." + string(payload))) hash := hsh.Sum(nil) r, s, err := ecdsa.Sign(v.rand, v.PrivateKey, hash) diff --git a/hs.go b/hs.go index 33bf0d7..42c6937 100644 --- a/hs.go +++ b/hs.go @@ -69,10 +69,10 @@ func (v hsValidator) validate(jwt *jwt) (bool, error) { func (v hsValidator) sign(jwt *jwt) error { jwt.Header.Algorithm = v.algorithm - jwt.rawEncode() + header, payload := jwt.rawEncode() mac := hmac.New(v.hashFunc, v.Key) - mac.Write([]byte(strings.Trim(string(jwt.headerRaw), "=") + "." + strings.Trim(string(jwt.payloadRaw), "="))) + mac.Write([]byte(string(header) + "." + string(payload))) jwt.Signature = []byte(base64.URLEncoding.EncodeToString(mac.Sum(nil))) return nil diff --git a/jwt.go b/jwt.go index 0cb4db5..3a1ac6e 100644 --- a/jwt.go +++ b/jwt.go @@ -44,7 +44,6 @@ type Payload struct { NotBefore *time.Time `json:"nbf,omitempty"` IssuedAt *time.Time `json:"iat,omitempty"` JWTId string `json:"jti,omitempty"` - raw []byte } // A Decoder is a centeralized reader and key used to consume and verify a @@ -66,19 +65,16 @@ type Encoder struct { type header struct { Algorithm Algorithm `json:"alg"` ContentType string `json:"typ"` - raw []byte } // A jwt is a unified structure of the components of a jwt. This structure is //used internally to aggregate components during encoding and decoding. type jwt struct { - Header *header - headerRaw []byte - Payload interface{} - claimsPayload *Payload - payloadRaw []byte - registeredPayload Payload - Signature []byte + Header *header + headerRaw []byte + Payload interface{} + payloadRaw []byte + Signature []byte } // NewDecoder creates an underlying Decoder with a given key and input reader @@ -91,7 +87,8 @@ func NewDecoder(r io.Reader, v Validator) *Decoder { // of the given token is verified and will return an error if a bad signature is // found. In addition if the jwt is using an unimplemented algorithm an error will // be returned as well. -func (dec *Decoder) Decode(v interface{}) error { +func (dec *Decoder) Decode(v interface{}) (err error) { + var valid bool buf := bufio.NewReader(dec.reader) input, err := buf.ReadString(byte(' ')) @@ -102,16 +99,16 @@ func (dec *Decoder) Decode(v interface{}) error { return err } - if valid, err := dec.validator.validate(jwt); !valid || err != nil { + if valid, err = dec.validator.validate(jwt); !valid || err != nil { if err != nil { return err } - return ErrBadSignature + err = ErrBadSignature } - return nil + return err } // NewEncoder creates an underlying Encoder with a given key and output writer @@ -135,7 +132,9 @@ func (enc *Encoder) Encode(v interface{}) error { return err } - fmt.Fprintf(enc.writer, "%s", jwt.token()) + header, payload := jwt.rawEncode() + + fmt.Fprintf(enc.writer, "%s.%s.%s", string(header), string(payload), string(jwt.Signature)) return nil } @@ -144,7 +143,7 @@ func (jwt *jwt) parseHeader(raw string) error { var err error var value []byte - if value, err = parseField(raw); err != nil { + if value, err = padB64String(raw); err != nil { return err } @@ -160,8 +159,7 @@ func (jwt *jwt) parseHeader(raw string) error { func parseJWT(input string, payload interface{}) (*jwt, error) { var err error jwt := &jwt{ - Header: &header{}, - claimsPayload: &Payload{}, + Header: &header{}, } fields := strings.Split(input, ".") @@ -183,35 +181,18 @@ func parseJWT(input string, payload interface{}) (*jwt, error) { return jwt, nil } -func (jwt *jwt) token() string { - header := strings.Trim(string(jwt.headerRaw), "=") - payload := strings.Trim(string(jwt.payloadRaw), "=") - signature := strings.Trim(string(jwt.Signature), "=") - - return fmt.Sprintf("%s.%s.%s", header, payload, signature) -} - -func (jwt *jwt) parsePayload(raw string, v interface{}) error { +func (jwt *jwt) parsePayload(raw string, v interface{}) (err error) { jwt.payloadRaw = []byte(raw) - value, err := parseField(raw) + value, err := padB64String(raw) if err != nil { return err } - // TODO: How to deal with json encoder errors? - err = json.NewDecoder(bytes.NewReader(value)).Decode(v) - - if err != nil { - return err - } - - json.NewDecoder(bytes.NewReader(value)).Decode(jwt.claimsPayload) - - return nil + return json.NewDecoder(bytes.NewReader(value)).Decode(v) } -func parseField(b64Value string) ([]byte, error) { +func padB64String(b64Value string) ([]byte, error) { if m := len(b64Value) % 4; m != 0 { b64Value += strings.Repeat("=", 4-m) } diff --git a/jwt_test.go b/jwt_test.go index e938f2e..1a27839 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -133,9 +133,9 @@ func TestEncodingSigning(t *testing.T) { Payload interface{} ValidToken string }{ - {HS256, struct{}{}, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.UGgJ_8f7TlqazSojqRAKzMJ0SUWJCJJ_9jDHe5nrhto"}, + {HS256, struct{}{}, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.UGgJ_8f7TlqazSojqRAKzMJ0SUWJCJJ_9jDHe5nrhto="}, {HS384, struct{}{}, "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9.e30.YGfeZ7CN9vKz4M2SINxTixlpUEDqsCZNx4LMJK62Lr_Eiptnikcf5XfgDd_7eVWe"}, - {HS512, struct{}{}, "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.e30.wHUM-6oRBExIgOk9MLOQ_80WqbuOmXXNuyTy4WmM_0WBM6pXld0mru8rZbc9-E314K9UhMkDNHbg2MRjIsCR3g"}, + {HS512, struct{}{}, "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.e30.wHUM-6oRBExIgOk9MLOQ_80WqbuOmXXNuyTy4WmM_0WBM6pXld0mru8rZbc9-E314K9UhMkDNHbg2MRjIsCR3g=="}, } for _, c := range cases { @@ -178,11 +178,11 @@ func ExampleEncoder() { } fmt.Println(tokenBuffer.String()) - // Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJCZW4gQ2FtcGJlbGwiLCJhZG1pbiI6dHJ1ZSwidXNlcl9pZCI6MTIzNH0.r4W8qDl8i8cUcRUxtA3hM0SZsLScHiBgBKZc_n_GrXI + // Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJCZW4gQ2FtcGJlbGwiLCJhZG1pbiI6dHJ1ZSwidXNlcl9pZCI6MTIzNH0.r4W8qDl8i8cUcRUxtA3hM0SZsLScHiBgBKZc_n_GrXI= } func ExampleDecoder() { - token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJCZW4gQ2FtcGJlbGwiLCJhZG1pbiI6dHJ1ZSwidXNlcl9pZCI6MTIzNH0.r4W8qDl8i8cUcRUxtA3hM0SZsLScHiBgBKZc_n_GrXI" + token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJCZW4gQ2FtcGJlbGwiLCJhZG1pbiI6dHJ1ZSwidXNlcl9pZCI6MTIzNH0.r4W8qDl8i8cUcRUxtA3hM0SZsLScHiBgBKZc_n_GrXI=" payload := &struct { Payload @@ -200,5 +200,5 @@ func ExampleDecoder() { } fmt.Printf("%+v\n", payload) - // Output: &{Payload:{Issuer:Ben Campbell Subject: Audience: ExpirationTime: NotBefore: IssuedAt: JWTId: raw:[]} Admin:true UserID:1234} + // Output: &{Payload:{Issuer:Ben Campbell Subject: Audience: ExpirationTime: NotBefore: IssuedAt: JWTId:} Admin:true UserID:1234} } diff --git a/rs.go b/rs.go index 92ad988..0e888b9 100644 --- a/rs.go +++ b/rs.go @@ -58,16 +58,16 @@ func (v RSValidator) validate(jwt *jwt) (bool, error) { } jwt.Header.Algorithm = v.algorithm - jwt.rawEncode() + header, payload := jwt.rawEncode() signature, err := base64.URLEncoding.DecodeString(string(jwt.Signature)) if err != nil { - return false, err + return false, ErrMalformedToken } hsh := v.hashType.New() - hsh.Write([]byte(string(jwt.headerRaw) + "." + string(jwt.payloadRaw))) + hsh.Write([]byte(string(header) + "." + string(payload))) hash := hsh.Sum(nil) err = rsa.VerifyPKCS1v15(v.PublicKey, v.hashType, hash, signature) @@ -81,10 +81,10 @@ func (v RSValidator) validate(jwt *jwt) (bool, error) { func (v RSValidator) sign(jwt *jwt) (err error) { jwt.Header.Algorithm = v.algorithm - jwt.rawEncode() + header, payload := jwt.rawEncode() hsh := v.hashType.New() - hsh.Write([]byte(string(jwt.headerRaw) + "." + string(jwt.payloadRaw))) + hsh.Write([]byte(string(header) + "." + string(payload))) hash := hsh.Sum(nil) signature, _ := rsa.SignPKCS1v15(v.randReader, v.PrivateKey, v.hashType, hash)